Compare commits
10 Commits
| Author | SHA1 | Date |
|---|---|---|
| | ecc2489eef | |
| | 1191fa6d19 | |
| | 6298158d3c | |
| | b0cf88704e | |
| | 1ef7918291 | |
| | a6f0edd9f3 | |
| | d4e43e6cb2 | |
| | 36a412138d | |
| | 229ffdf51a | |
| | 53b6ae3a76 | |
.gitignore (vendored): 26 changed lines
```diff
@@ -22,5 +22,27 @@ Pipfile.lock # Consider if you want to include or exclude
 ._*
 *~.nib
 *~.xib
-README.md.old
-oai.zip
+# Claude Code local settings
+.claude/
+
+# Added by author
+*.zip
+.note
+diagnose.py
+*.log
+*.xml
+build*
+*.spec
+compiled/
+images/oai-iOS-Default-1024x1024@1x.png
+images/oai.icon/
+b0.sh
+*.bak
+*.old
+*.sh
+*.back
+requirements.txt
+system_prompt.txt
+CLAUDE*
+SESSION*_COMPLETE.md
```
README.md: 436 changed lines

````diff
@@ -1,228 +1,326 @@
-# oAI - OpenRouter AI Chat
+# oAI - OpenRouter AI Chat Client
 
-A terminal-based chat interface for OpenRouter API with conversation management, cost tracking, and rich formatting.
+A powerful, modern **Textual TUI** chat client for OpenRouter API with **MCP (Model Context Protocol)** support, enabling AI to access local files and query SQLite databases.
 
-## Description
-
-oAI is a command-line chat application that provides an interactive interface to OpenRouter's AI models. It features conversation persistence, file attachments, export capabilities, and detailed session metrics.
-
 ## Features
 
-- Interactive chat with multiple AI models via OpenRouter
-- Model selection with search functionality
-- Conversation save/load/export (Markdown, JSON, HTML)
-- File attachment support (code files and images)
-- Session cost tracking and credit monitoring
-- Rich terminal formatting with syntax highlighting
-- Persistent command history
-- Configurable system prompts and token limits
-- SQLite-based configuration and conversation storage
+### Core Features
+- 🖥️ **Modern Textual TUI** with async streaming and beautiful interface
+- 🤖 Interactive chat with 300+ AI models via OpenRouter
+- 🔍 Model selection with search, filtering, and capability icons
+- 💾 Conversation save/load/export (Markdown, JSON, HTML)
+- 📎 File attachments (images, PDFs, code files)
+- 💰 Real-time cost tracking and credit monitoring
+- 🎨 Dark theme with syntax highlighting and Markdown rendering
+- 📝 Command history navigation (Up/Down arrows)
+- 🌐 Online mode (web search capabilities)
+- 🧠 Conversation memory toggle
+- ⌨️ Keyboard shortcuts (F1=Help, F2=Models, Ctrl+S=Stats)
+
+### MCP Integration
+- 🔧 **File Mode**: AI can read, search, and list local files
+  - Automatic .gitignore filtering
+  - Virtual environment exclusion
+  - Large file handling (auto-truncates >50KB)
+- ✍️ **Write Mode**: AI can modify files with permission
+  - Create, edit, delete files
+  - Move, copy, organize files
+  - Always requires explicit opt-in
+- 🗄️ **Database Mode**: AI can query SQLite databases
+  - Read-only access (safe)
+  - Schema inspection
+  - Full SQL query support
 
 ## Requirements
 
-- Python 3.7 or higher
-- OpenRouter API key (get one at https://openrouter.ai)
+- Python 3.10-3.13
+- OpenRouter API key ([get one here](https://openrouter.ai))
 
-## Screenshot
-
-[<img src="https://gitlab.pm/rune/oai/raw/branch/main/images/screenshot_01.png">](https://gitlab.pm/rune/oai/src/branch/main/README.md)
-
-Screenshot of `/help` screen.
-
 ## Installation
 
-### 1. Install Dependencies
+### Option 1: Pre-built Binary (macOS/Linux) (Recommended)
 
-Use the included `requirements.txt` file to install the dependencies:
+Download from [Releases](https://gitlab.pm/rune/oai/releases):
+- **macOS (Apple Silicon)**: `oai_v3.0.0_mac_arm64.zip`
+- **Linux (x86_64)**: `oai_v3.0.0_linux_x86_64.zip`
 
 ```bash
-pip install -r requirements.txt
-```
-
-### 2. Make the Script Executable
-
-```bash
-chmod +x oai.py
-```
-
-### 3. Copy to PATH
-
-Copy the script to a directory in your `$PATH` environment variable. Common locations include:
-
-```bash
-# Option 1: System-wide (requires sudo)
-sudo cp oai.py /usr/local/bin/oai
-
-# Option 2: User-local (recommended)
+# Extract and install
+unzip oai_v3.0.0_*.zip
 mkdir -p ~/.local/bin
-cp oai.py ~/.local/bin/oai
+mv oai ~/.local/bin/
 
-# Add to PATH if not already (add to ~/.bashrc or ~/.zshrc)
+# macOS only: Remove quarantine and approve
+xattr -cr ~/.local/bin/oai
+# Then right-click oai in Finder → Open With → Terminal → Click "Open"
+```
+
+### Add to PATH
+
+```bash
+# Add to ~/.zshrc or ~/.bashrc
 export PATH="$HOME/.local/bin:$PATH"
 ```
 
-### 4. Verify Installation
+### Option 2: Install from Source
 
 ```bash
+# Clone the repository
+git clone https://gitlab.pm/rune/oai.git
+cd oai
+
+# Install with pip
+pip install -e .
+```
+
+## Quick Start
+
+```bash
+# Start oAI (launches TUI)
 oai
+
+# Or with options
+oai --model gpt-4o --online --mcp
+
+# Show version
+oai version
 ```
 
-### 5. Alternative Installation (for *nix systems)
+On first run, you'll be prompted for your OpenRouter API key.
 
-If you have issues with the above method, you can add an alias in your `.bashrc`, `.zshrc`, etc.
-
-```bash
-alias oai='python3 <path to your file>'
-```
-
-On first run, you will be prompted to enter your OpenRouter API key.
-
-## Usage
-
-### Starting the Application
-
-```bash
-oai
-```
-
 ### Basic Commands
 
-```
-/help          Show all available commands
-/model         Select an AI model
-/config api    Set OpenRouter API key
-exit           Quit the application
+```bash
+# In the TUI interface:
+/model      # Select AI model (or press F2)
+/help       # Show all commands (or press F1)
+/mcp on     # Enable file/database access
+/stats      # View session statistics (or press Ctrl+S)
+/config     # View configuration settings
+/credits    # Check account credits
+Ctrl+Q      # Quit
 ```
 
-### Configuration
+## MCP (Model Context Protocol)
 
-All configuration is stored in `~/.config/oai/`:
-- `oai_config.db` - SQLite database for settings and conversations
-- `oai.log` - Application log file
-- `history.txt` - Command history
+MCP allows the AI to interact with your local files and databases.
 
-### Common Workflows
+### File Access
 
-**Select a Model:**
-```
-/model
+```bash
+/mcp on              # Enable MCP
+/mcp add ~/Projects  # Grant access to folder
+/mcp list            # View allowed folders
+
+# Now ask the AI:
+"List all Python files in Projects"
+"Read and explain main.py"
+"Search for files containing 'TODO'"
 ```
 
-**Paste from clipboard:**
-Paste and send content to model:
-```
-/paste
+### Write Mode
+
+```bash
+/mcp write on  # Enable file modifications
+
+# AI can now:
+"Create a new file called utils.py"
+"Edit config.json and update the API URL"
+"Delete the old backup files"  # Always asks for confirmation
 ```
 
-Paste with a prompt and send content to model:
-```
-/paste Analyze this text
-```
+### Database Mode
 
-**Start Chatting:**
-```
-You> Hello, how are you?
-```
+```bash
+/mcp add db ~/app/data.db  # Add database
+/mcp db 1                  # Switch to database mode
 
-**Attach Files:**
-```
-You> Debug this code @/path/to/script.py
-You> Analyze this image @/path/to/image.png
-```
-
-**Save Conversation:**
-```
-/save my_conversation
-```
-
-**Export to File:**
-```
-/export md notes.md
-/export json backup.json
-/export html report.html
-```
-
-**View Session Stats:**
-```
-/stats
-/credits
-```
-
-**Previous command input:**
-
-Use the Up/Down arrows to recall earlier `/` commands and earlier model input, then press Enter to execute the same command or resend the same input.
+# Ask the AI:
+"Show all tables"
+"Find users created this month"
+"What's the schema for the orders table?"
+```
 
 ## Command Reference
 
-Use `/help` within the application for a complete command reference organized by category:
-- Session Commands
-- Model Commands
-- Configuration
-- Token & System
-- Conversation Management
-- Monitoring & Stats
-- File Attachments
+### Chat Commands
+| Command | Description |
+|---------|-------------|
+| `/help [cmd]` | Show help |
+| `/model [search]` | Select model |
+| `/info [model]` | Model details |
+| `/memory on\|off` | Toggle context |
+| `/online on\|off` | Toggle web search |
+| `/retry` | Resend last message |
+| `/clear` | Clear screen |
 
-## Configuration Options
+### MCP Commands
+| Command | Description |
+|---------|-------------|
+| `/mcp on\|off` | Enable/disable MCP |
+| `/mcp status` | Show MCP status |
+| `/mcp add <path>` | Add folder |
+| `/mcp add db <path>` | Add database |
+| `/mcp list` | List folders |
+| `/mcp db list` | List databases |
+| `/mcp db <n>` | Switch to database |
+| `/mcp files` | Switch to file mode |
+| `/mcp write on\|off` | Toggle write mode |
 
-- API Key: `/config api`
-- Base URL: `/config url`
-- Streaming: `/config stream on|off`
-- Default Model: `/config model`
-- Cost Warning: `/config costwarning <amount>`
-- Max Token Limit: `/config maxtoken <value>`
+### Conversation Commands
+| Command | Description |
+|---------|-------------|
+| `/save <name>` | Save conversation |
+| `/load <name>` | Load conversation |
+| `/list` | List saved conversations |
+| `/delete <name>` | Delete conversation |
+| `/export md\|json\|html <file>` | Export |
 
-## File Support
+### Configuration
+| Command | Description |
+|---------|-------------|
+| `/config` | View settings |
+| `/config api` | Set API key |
+| `/config model <id>` | Set default model |
+| `/config stream on\|off` | Toggle streaming |
+| `/stats` | Session statistics |
+| `/credits` | Check credits |
 
-**Supported Code Extensions:**
-.py, .js, .ts, .cs, .java, .c, .cpp, .h, .hpp, .rb, .ruby, .php, .swift, .kt, .kts, .go, .sh, .bat, .ps1, .R, .scala, .pl, .lua, .dart, .elm, .xml, .json, .yaml, .yml, .md, .txt
+## CLI Options
 
-**Image Support:**
-Any image format with proper MIME type (PNG, JPEG, GIF, etc.)
+```bash
+oai [OPTIONS]
 
-## Data Storage
+Options:
+  -m, --model TEXT   Model ID to use
+  -s, --system TEXT  System prompt
+  -o, --online       Enable online mode
+  --mcp              Enable MCP server
+  -v, --version      Show version
+  --help             Show help
+```
 
-- Configuration: `~/.config/oai/oai_config.db`
-- Logs: `~/.config/oai/oai.log`
-- History: `~/.config/oai/history.txt`
+Commands:
+```bash
+oai            # Launch TUI (default)
+oai version    # Show version information
+oai --help     # Show help message
+```
+
+## Configuration
+
+Configuration is stored in `~/.config/oai/`:
+
+| File | Purpose |
+|------|---------|
+| `oai_config.db` | Settings, conversations, MCP config |
+| `oai.log` | Application logs |
+| `history.txt` | Command history |
+
+## Project Structure
+
+```
+oai/
+├── oai/
+│   ├── __init__.py
+│   ├── __main__.py      # Entry point for python -m oai
+│   ├── cli.py           # Main CLI entry point
+│   ├── constants.py     # Configuration constants
+│   ├── commands/        # Slash command handlers
+│   ├── config/          # Settings and database
+│   ├── core/            # Chat client and session
+│   ├── mcp/             # MCP server and tools
+│   ├── providers/       # AI provider abstraction
+│   ├── tui/             # Textual TUI interface
+│   │   ├── app.py       # Main TUI application
+│   │   ├── widgets/     # Custom widgets
+│   │   ├── screens/     # Modal screens
+│   │   └── styles.tcss  # TUI styling
+│   └── utils/           # Logging, export, etc.
+├── pyproject.toml       # Package configuration
+├── build.sh             # Binary build script
+└── README.md
+```
+
+## Troubleshooting
+
+### macOS Binary Issues
+
+```bash
+# Remove quarantine attribute
+xattr -cr ~/.local/bin/oai
+
+# Then in Finder: right-click oai → Open With → Terminal → Click "Open"
+# After this, oai works from any terminal
+```
+
+### MCP Not Working
+
+```bash
+# Check if model supports function calling
+/info  # Look for "tools" in supported parameters
+
+# Check MCP status
+/mcp status
+
+# View logs
+tail -f ~/.config/oai/oai.log
+```
+
+### Import Errors
+
+```bash
+# Reinstall package
+pip install -e . --force-reinstall
+```
+
+## Version History
+
+### v3.0.0 (Current)
+- 🎨 **Complete migration to Textual TUI** - Modern async terminal interface
+- 🗑️ **Removed CLI interface** - TUI-only for cleaner codebase (11.6% smaller)
+- 🖱️ **Modal screens** - Help, stats, config, credits, model selector
+- ⌨️ **Keyboard shortcuts** - F1 (help), F2 (models), Ctrl+S (stats), etc.
+- 🎯 **Capability indicators** - Visual icons for model features (vision, tools, online)
+- 🎨 **Consistent dark theme** - Professional styling throughout
+- 📊 **Enhanced model selector** - Search, filter, capability columns
+- 🚀 **Default command** - Just run `oai` to launch TUI
+- 🧹 **Code cleanup** - Removed 1,300+ lines of CLI code
+
+### v2.1.0
+- 🏗️ Complete codebase refactoring to modular package structure
+- 🔌 Extensible provider architecture for adding new AI providers
+- 📦 Proper Python packaging with pyproject.toml
+- ✨ MCP integration (file access, write mode, database queries)
+- 🔧 Command registry pattern for slash commands
+- 📊 Improved cost tracking and session statistics
+
+### v1.9.x
+- Single-file implementation
+- Core chat functionality
+- File attachments
+- Conversation management
 
 ## License
 
-MIT License
+MIT License - See [LICENSE](LICENSE) for details.
 
-Copyright (c) 2024 Rune Olsen
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-Full license: https://opensource.org/licenses/MIT
-
 ## Author
 
 **Rune Olsen**
+- Project: https://iurl.no/oai
+- Repository: https://gitlab.pm/rune/oai
 
-Blog: https://blog.rune.pm
+## Contributing
 
-## Version
+1. Fork the repository
+2. Create a feature branch
+3. Submit a pull request
 
-1.0
+---
 
-## Support
-
-For issues, questions, or contributions, visit https://iurl.no/oai and create an issue.
+**⭐ Star this project if you find it useful!**
````
oai/__init__.py (new file, 26 lines)

```python
"""
oAI - OpenRouter AI Chat Client

A feature-rich terminal-based chat application that provides an interactive CLI
interface to OpenRouter's unified AI API with advanced Model Context Protocol (MCP)
integration for filesystem and database access.

Author: Rune
License: MIT
"""

__version__ = "3.0.0-b2"
__author__ = "Rune"
__license__ = "MIT"

# Lazy imports to avoid circular dependencies and improve startup time
# Full imports are available via submodules:
#   from oai.config import Settings, Database
#   from oai.providers import OpenRouterProvider, AIProvider
#   from oai.mcp import MCPManager

__all__ = [
    "__version__",
    "__author__",
    "__license__",
]
```
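The comment in `__init__.py` above defers submodule imports to keep package startup cheap and avoid circular dependencies. One standard way to realize that is a module-level `__getattr__` (PEP 562), which imports a submodule only on first attribute access. A minimal, self-contained sketch of the idea, using a stand-in module named `demo_pkg` (not part of oAI):

```python
import importlib
import sys
import types

# Build a stand-in package module to illustrate lazy attribute access.
demo_pkg = types.ModuleType("demo_pkg")
demo_pkg.__version__ = "3.0.0-b2"

# Map lazily exposed attribute names to the real modules backing them.
_lazy = {"json_tools": "json"}

def _getattr(name):
    # Called only when normal attribute lookup on the module fails.
    if name in _lazy:
        return importlib.import_module(_lazy[name])
    raise AttributeError(name)

demo_pkg.__getattr__ = _getattr  # PEP 562 module-level __getattr__
sys.modules["demo_pkg"] = demo_pkg

import demo_pkg as pkg        # resolves from sys.modules
lazy_json = pkg.json_tools    # first access triggers the real import
```

The version string stays importable instantly, while heavyweight submodules load only when actually used.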
oai/__main__.py (new file, 8 lines)

```python
"""
Entry point for running oAI as a module: python -m oai
"""

from oai.cli import main

if __name__ == "__main__":
    main()
```
oai/cli.py (new file, 199 lines)

```python
"""
Main CLI entry point for oAI.

This module provides the command-line interface for the oAI TUI application.
"""

import sys
from typing import Optional

import typer

from oai import __version__
from oai.commands import register_all_commands
from oai.config.settings import Settings
from oai.constants import APP_URL, APP_VERSION
from oai.core.client import AIClient
from oai.core.session import ChatSession
from oai.mcp.manager import MCPManager
from oai.utils.logging import LoggingManager, get_logger

# Create Typer app
app = typer.Typer(
    name="oai",
    help=f"oAI - OpenRouter AI Chat Client (TUI)\n\nVersion: {APP_VERSION}",
    add_completion=False,
    epilog="For more information, visit: " + APP_URL,
)


@app.callback(invoke_without_command=True)
def main_callback(
    ctx: typer.Context,
    version_flag: bool = typer.Option(
        False,
        "--version",
        "-v",
        help="Show version information",
        is_flag=True,
    ),
    model: Optional[str] = typer.Option(
        None,
        "--model",
        "-m",
        help="Model ID to use",
    ),
    system: Optional[str] = typer.Option(
        None,
        "--system",
        "-s",
        help="System prompt",
    ),
    online: bool = typer.Option(
        False,
        "--online",
        "-o",
        help="Enable online mode",
    ),
    mcp: bool = typer.Option(
        False,
        "--mcp",
        help="Enable MCP server",
    ),
) -> None:
    """Main callback - launches TUI by default."""
    if version_flag:
        typer.echo(f"oAI version {APP_VERSION}")
        raise typer.Exit()

    # If no subcommand provided, launch TUI
    if ctx.invoked_subcommand is None:
        _launch_tui(model, system, online, mcp)


def _launch_tui(
    model: Optional[str] = None,
    system: Optional[str] = None,
    online: bool = False,
    mcp: bool = False,
) -> None:
    """Launch the Textual TUI interface."""
    # Setup logging
    logging_manager = LoggingManager()
    logging_manager.setup()
    logger = get_logger()

    # Load settings
    settings = Settings.load()

    # Check API key
    if not settings.api_key:
        typer.echo("Error: No API key configured", err=True)
        typer.echo("Run: oai config api to set your API key", err=True)
        raise typer.Exit(1)

    # Initialize client
    try:
        client = AIClient(
            api_key=settings.api_key,
            base_url=settings.base_url,
        )
    except Exception as e:
        typer.echo(f"Error: Failed to initialize client: {e}", err=True)
        raise typer.Exit(1)

    # Register commands
    register_all_commands()

    # Initialize MCP manager (always create it, even if not enabled)
    mcp_manager = MCPManager()
    if mcp:
        try:
            result = mcp_manager.enable()
            if result["success"]:
                logger.info("MCP server enabled in files mode")
            else:
                logger.warning(f"MCP: {result.get('error', 'Failed to enable')}")
        except Exception as e:
            logger.warning(f"Failed to enable MCP: {e}")

    # Create session with MCP manager
    session = ChatSession(
        client=client,
        settings=settings,
        mcp_manager=mcp_manager,
    )

    # Set system prompt if provided
    if system:
        session.set_system_prompt(system)

    # Enable online mode if requested
    if online:
        session.online_enabled = True

    # Set model if specified, otherwise use default
    if model:
        raw_model = client.get_raw_model(model)
        if raw_model:
            session.set_model(raw_model)
        else:
            logger.warning(f"Model '{model}' not found")
    elif settings.default_model:
        raw_model = client.get_raw_model(settings.default_model)
        if raw_model:
            session.set_model(raw_model)
        else:
            logger.warning(f"Default model '{settings.default_model}' not available")

    # Run Textual app
    from oai.tui.app import oAIChatApp

    app_instance = oAIChatApp(session, settings, model)
    app_instance.run()


@app.command()
def tui(
    model: Optional[str] = typer.Option(
        None,
        "--model",
        "-m",
        help="Model ID to use",
    ),
    system: Optional[str] = typer.Option(
        None,
        "--system",
        "-s",
        help="System prompt",
    ),
    online: bool = typer.Option(
        False,
        "--online",
        "-o",
        help="Enable online mode",
    ),
    mcp: bool = typer.Option(
        False,
        "--mcp",
        help="Enable MCP server",
    ),
) -> None:
    """Start Textual TUI interface (alias for just running 'oai')."""
    _launch_tui(model, system, online, mcp)


@app.command()
def version() -> None:
    """Show version information."""
    typer.echo(f"oAI version {APP_VERSION}")
    typer.echo(f"Visit {APP_URL} for more information")


def main() -> None:
    """Entry point for the CLI."""
    app()


if __name__ == "__main__":
    main()
```
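`cli.py` relies on Typer's callback pattern with `invoke_without_command=True`: the callback runs even for a bare `oai`, handles `--version` as an early exit, and launches the TUI only when no subcommand was named. A dependency-free sketch of that dispatch logic, with illustrative names (this is not oAI's actual API, just the control flow):

```python
from typing import Callable, Dict, List

APP_VERSION = "3.0.0-b2"  # mirrors oai.constants.APP_VERSION

def dispatch(argv: List[str], commands: Dict[str, Callable[[], str]]) -> str:
    """Mimic the invoke_without_command callback: version flag first,
    then a named subcommand, otherwise fall through to the TUI."""
    if "--version" in argv or "-v" in argv:
        return f"oAI version {APP_VERSION}"   # early exit, like typer.Exit()
    sub = next((a for a in argv if not a.startswith("-")), None)
    if sub in commands:                       # explicit subcommand wins
        return commands[sub]()
    return "launch TUI"                       # no subcommand: default action

commands = {"version": lambda: f"oAI version {APP_VERSION}"}
```

Both `oai --version` and the `oai version` subcommand reach the same output, while `oai` alone (with or without option flags) falls through to the TUI launcher, matching the behavior described in the README diff.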
oai/commands/__init__.py (new file, 24 lines)

```python
"""
Command system for oAI.

This module provides a command registry and handler system
for processing slash commands in the chat interface.
"""

from oai.commands.registry import (
    Command,
    CommandRegistry,
    CommandContext,
    CommandResult,
    registry,
)
from oai.commands.handlers import register_all_commands

__all__ = [
    "Command",
    "CommandRegistry",
    "CommandContext",
    "CommandResult",
    "registry",
    "register_all_commands",
]
```
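The package above re-exports the command registry pattern used for slash commands: handlers register under a name, and user input is routed to the matching handler. A toy sketch of that dispatch shape (heavily simplified; the real `Command` ABC in `registry.py` also carries help metadata, aliases, and a richer result type):

```python
from dataclasses import dataclass
from typing import Callable, Dict

@dataclass
class Result:
    """Minimal stand-in for CommandResult: outcome plus a message."""
    ok: bool
    message: str = ""

class Registry:
    """Minimal slash-command registry mapping name -> handler."""

    def __init__(self) -> None:
        self._handlers: Dict[str, Callable[[str], Result]] = {}

    def register(self, name: str, handler: Callable[[str], Result]) -> None:
        self._handlers[name] = handler

    def dispatch(self, line: str) -> Result:
        # Split "/save my_chat" into command word and argument string.
        parts = line.split(maxsplit=1)
        cmd, args = parts[0], (parts[1] if len(parts) > 1 else "")
        handler = self._handlers.get(cmd)
        if handler is None:
            return Result(ok=False, message=f"unknown command: {cmd}")
        return handler(args)

toy_registry = Registry()
toy_registry.register("/save", lambda args: Result(ok=True, message=f"saved '{args}'"))
```

The real module-level `registry` singleton plays the same role: `register_all_commands()` populates it once at startup, and the TUI feeds each `/`-prefixed input line through dispatch.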
oai/commands/handlers.py (new file, 1479 lines): file diff suppressed because it is too large.
382
oai/commands/registry.py
Normal file
382
oai/commands/registry.py
Normal file
@@ -0,0 +1,382 @@
|
|||||||
|
"""
|
||||||
|
Command registry for oAI.
|
||||||
|
|
||||||
|
This module defines the command system infrastructure including
|
||||||
|
the Command base class, CommandContext for state, and CommandRegistry
|
||||||
|
for managing available commands.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
from oai.utils.logging import get_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from oai.config.settings import Settings
|
||||||
|
from oai.providers.base import AIProvider, ModelInfo
|
||||||
|
from oai.mcp.manager import MCPManager
|
||||||
|
|
||||||
|
|
||||||
|
class CommandStatus(str, Enum):
|
||||||
|
"""Status of command execution."""
|
||||||
|
|
||||||
|
SUCCESS = "success"
|
||||||
|
ERROR = "error"
|
||||||
|
CONTINUE = "continue" # Continue to next handler
|
||||||
|
EXIT = "exit" # Exit the application
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CommandResult:
|
||||||
|
"""
|
||||||
|
Result of a command execution.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
status: Execution status
|
||||||
|
message: Optional message to display
|
||||||
|
data: Optional data payload
|
||||||
|
should_continue: Whether to continue the main loop
|
||||||
|
"""
|
||||||
|
|
||||||
|
status: CommandStatus = CommandStatus.SUCCESS
|
||||||
|
message: Optional[str] = None
|
||||||
|
data: Optional[Any] = None
|
||||||
|
should_continue: bool = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def success(cls, message: Optional[str] = None, data: Any = None) -> "CommandResult":
|
||||||
|
"""Create a success result."""
|
||||||
|
return cls(status=CommandStatus.SUCCESS, message=message, data=data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def error(cls, message: str) -> "CommandResult":
|
||||||
|
"""Create an error result."""
|
||||||
|
return cls(status=CommandStatus.ERROR, message=message)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def exit(cls, message: Optional[str] = None) -> "CommandResult":
|
||||||
|
"""Create an exit result."""
|
||||||
|
return cls(status=CommandStatus.EXIT, message=message, should_continue=False)


@dataclass
class CommandContext:
    """
    Context object providing state to command handlers.

    Contains all the session state needed by commands including
    settings, provider, conversation history, and MCP manager.

    Attributes:
        settings: Application settings
        provider: AI provider instance
        mcp_manager: MCP manager instance
        selected_model: Currently selected model
        session_history: Conversation history
        session_system_prompt: Current system prompt
        memory_enabled: Whether memory is enabled
        online_enabled: Whether online mode is enabled
        session_tokens: Session token counts
        session_cost: Session cost total
    """

    settings: Optional["Settings"] = None
    provider: Optional["AIProvider"] = None
    mcp_manager: Optional["MCPManager"] = None
    selected_model: Optional["ModelInfo"] = None
    selected_model_raw: Optional[Dict[str, Any]] = None
    session_history: List[Dict[str, Any]] = field(default_factory=list)
    session_system_prompt: str = ""
    memory_enabled: bool = True
    memory_start_index: int = 0
    online_enabled: bool = False
    middle_out_enabled: bool = False
    session_max_token: int = 0
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_cost: float = 0.0
    message_count: int = 0
    is_tui: bool = False  # Flag for TUI mode
    current_index: int = 0


@dataclass
class CommandHelp:
    """
    Help information for a command.

    Attributes:
        description: Brief description
        usage: Usage syntax
        examples: List of (description, example) tuples
        notes: Additional notes
        aliases: Command aliases
    """

    description: str
    usage: str = ""
    examples: List[tuple] = field(default_factory=list)
    notes: str = ""
    aliases: List[str] = field(default_factory=list)


class Command(ABC):
    """
    Abstract base class for all commands.

    Commands implement the execute method to handle their logic.
    They can also provide help information and aliases.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Get the primary command name (e.g., '/help')."""
        pass

    @property
    def aliases(self) -> List[str]:
        """Get command aliases (e.g., ['/h'] for help)."""
        return []

    @property
    @abstractmethod
    def help(self) -> CommandHelp:
        """Get command help information."""
        pass

    @abstractmethod
    def execute(self, args: str, context: CommandContext) -> CommandResult:
        """
        Execute the command.

        Args:
            args: Arguments passed to the command
            context: Command execution context

        Returns:
            CommandResult indicating success/failure
        """
        pass

    def matches(self, input_text: str) -> bool:
        """
        Check if this command matches the input.

        Args:
            input_text: User input text

        Returns:
            True if this command should handle the input
        """
        input_lower = input_text.lower()
        cmd_word = input_lower.split()[0] if input_lower.split() else ""

        # Check primary name
        if cmd_word == self.name.lower():
            return True

        # Check aliases
        for alias in self.aliases:
            if cmd_word == alias.lower():
                return True

        return False

    def get_args(self, input_text: str) -> str:
        """
        Extract arguments from the input text.

        Args:
            input_text: Full user input

        Returns:
            Arguments portion of the input
        """
        parts = input_text.split(maxsplit=1)
        return parts[1] if len(parts) > 1 else ""


class CommandRegistry:
    """
    Registry for managing available commands.

    Provides registration, lookup, and execution of commands.
    """

    def __init__(self):
        """Initialize an empty command registry."""
        self._commands: Dict[str, Command] = {}
        self._aliases: Dict[str, str] = {}
        self.logger = get_logger()

    def register(self, command: Command) -> None:
        """
        Register a command.

        Args:
            command: Command instance to register

        Raises:
            ValueError: If command name already registered
        """
        name = command.name.lower()

        if name in self._commands:
            raise ValueError(f"Command '{name}' already registered")

        self._commands[name] = command

        # Register aliases
        for alias in command.aliases:
            alias_lower = alias.lower()
            if alias_lower in self._aliases:
                self.logger.warning(
                    f"Alias '{alias}' already registered, overwriting"
                )
            self._aliases[alias_lower] = name

        self.logger.debug(f"Registered command: {name}")

    def register_function(
        self,
        name: str,
        handler: Callable[[str, CommandContext], CommandResult],
        description: str,
        usage: str = "",
        aliases: Optional[List[str]] = None,
        examples: Optional[List[tuple]] = None,
        notes: str = "",
    ) -> None:
        """
        Register a function-based command.

        Convenience method for simple commands that don't need
        a full Command class.

        Args:
            name: Command name (e.g., '/help')
            handler: Function to execute
            description: Help description
            usage: Usage syntax
            aliases: Command aliases
            examples: Example usages
            notes: Additional notes
        """
        aliases = aliases or []
        examples = examples or []

        class FunctionCommand(Command):
            @property
            def name(self) -> str:
                return name

            @property
            def aliases(self) -> List[str]:
                return aliases

            @property
            def help(self) -> CommandHelp:
                return CommandHelp(
                    description=description,
                    usage=usage,
                    examples=examples,
                    notes=notes,
                    aliases=aliases,
                )

            def execute(self, args: str, context: CommandContext) -> CommandResult:
                return handler(args, context)

        self.register(FunctionCommand())

    def get(self, name: str) -> Optional[Command]:
        """
        Get a command by name or alias.

        Args:
            name: Command name or alias

        Returns:
            Command instance or None if not found
        """
        name_lower = name.lower()

        # Check direct match
        if name_lower in self._commands:
            return self._commands[name_lower]

        # Check aliases
        if name_lower in self._aliases:
            return self._commands[self._aliases[name_lower]]

        return None

    def find(self, input_text: str) -> Optional[Command]:
        """
        Find a command that matches the input.

        Args:
            input_text: User input text

        Returns:
            Matching Command or None
        """
        cmd_word = input_text.lower().split()[0] if input_text.split() else ""
        return self.get(cmd_word)

    def execute(self, input_text: str, context: CommandContext) -> Optional[CommandResult]:
        """
        Execute a command matching the input.

        Args:
            input_text: User input text
            context: Execution context

        Returns:
            CommandResult or None if no matching command
        """
        command = self.find(input_text)
        if command:
            args = command.get_args(input_text)
            self.logger.debug(f"Executing command: {command.name} with args: {args}")
            return command.execute(args, context)
        return None

    def is_command(self, input_text: str) -> bool:
        """
        Check if input is a valid command.

        Args:
            input_text: User input text

        Returns:
            True if input matches a registered command
        """
        return self.find(input_text) is not None

    def list_commands(self) -> List[Command]:
        """
        Get all registered commands.

        Returns:
            List of Command instances
        """
        return list(self._commands.values())

    def get_all_names(self) -> List[str]:
        """
        Get all command names and aliases.

        Returns:
            List of command names including aliases
        """
        names = list(self._commands.keys())
        names.extend(self._aliases.keys())
        return sorted(set(names))


# Global registry instance
registry = CommandRegistry()
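The alias-resolution and dispatch flow of the registry above can be exercised with a minimal, self-contained sketch. The `Registry` and `Result` classes here are simplified stand-ins that only mirror the pattern; they are not the oai classes and import nothing from the project.

```python
# Minimal stand-in for the command dispatch pattern: lowercase the first
# word, resolve aliases to the primary name, then call the handler.
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Sequence


@dataclass
class Result:
    message: str


class Registry:
    def __init__(self) -> None:
        self._commands: Dict[str, Callable[[str], Result]] = {}
        self._aliases: Dict[str, str] = {}

    def register(self, name: str, handler: Callable[[str], Result],
                 aliases: Sequence[str] = ()) -> None:
        self._commands[name.lower()] = handler
        for alias in aliases:
            self._aliases[alias.lower()] = name.lower()

    def execute(self, input_text: str) -> Optional[Result]:
        parts = input_text.split(maxsplit=1)
        if not parts:
            return None
        cmd = self._aliases.get(parts[0].lower(), parts[0].lower())
        handler = self._commands.get(cmd)
        if handler is None:
            return None  # not a registered command
        return handler(parts[1] if len(parts) > 1 else "")


reg = Registry()
reg.register("/echo", lambda args: Result(message=args), aliases=["/e"])
print(reg.execute("/e hello").message)  # alias /e resolves to /echo -> "hello"
```

The real `CommandRegistry` adds help metadata, logging, and duplicate-name checks on top of this same lookup shape.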
11  oai/config/__init__.py  Normal file
@@ -0,0 +1,11 @@
"""
Configuration management for oAI.

This package handles all configuration persistence, settings management,
and database operations for the application.
"""

from oai.config.settings import Settings
from oai.config.database import Database

__all__ = ["Settings", "Database"]
472  oai/config/database.py  Normal file
@@ -0,0 +1,472 @@
"""
Database persistence layer for oAI.

This module provides a clean abstraction for SQLite operations including
configuration storage, conversation persistence, and MCP statistics tracking.
All database operations are centralized here for maintainability.
"""

import sqlite3
import json
import datetime
from pathlib import Path
from typing import Optional, List, Dict, Any
from contextlib import contextmanager

from oai.constants import DATABASE_FILE, CONFIG_DIR


class Database:
    """
    SQLite database manager for oAI.

    Handles all database operations including:
    - Configuration key-value storage
    - Conversation session persistence
    - MCP configuration and statistics
    - Database registrations for MCP

    Uses context managers for safe connection handling and supports
    automatic table creation on first use.
    """

    def __init__(self, db_path: Optional[Path] = None):
        """
        Initialize the database manager.

        Args:
            db_path: Optional custom database path. Defaults to standard location.
        """
        self.db_path = db_path or DATABASE_FILE
        self._ensure_directories()
        self._ensure_tables()

    def _ensure_directories(self) -> None:
        """Ensure the configuration directory exists."""
        CONFIG_DIR.mkdir(parents=True, exist_ok=True)

    @contextmanager
    def _connection(self):
        """
        Context manager for database connections.

        Yields:
            sqlite3.Connection: Active database connection

        Example:
            with self._connection() as conn:
                conn.execute("SELECT * FROM config")
        """
        conn = sqlite3.connect(str(self.db_path))
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    def _ensure_tables(self) -> None:
        """Create all required tables if they don't exist."""
        with self._connection() as conn:
            # Main configuration table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS config (
                    key TEXT PRIMARY KEY,
                    value TEXT NOT NULL
                )
            """)

            # Conversation sessions table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS conversation_sessions (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT NOT NULL,
                    timestamp TEXT NOT NULL,
                    data TEXT NOT NULL
                )
            """)

            # MCP configuration table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS mcp_config (
                    key TEXT PRIMARY KEY,
                    value TEXT NOT NULL
                )
            """)

            # MCP statistics table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS mcp_stats (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp TEXT NOT NULL,
                    tool_name TEXT NOT NULL,
                    folder TEXT,
                    success INTEGER NOT NULL,
                    error_message TEXT
                )
            """)

            # MCP databases table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS mcp_databases (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    path TEXT NOT NULL UNIQUE,
                    name TEXT NOT NULL,
                    size INTEGER,
                    tables TEXT,
                    added_timestamp TEXT NOT NULL
                )
            """)

    # =========================================================================
    # CONFIGURATION METHODS
    # =========================================================================

    def get_config(self, key: str) -> Optional[str]:
        """
        Retrieve a configuration value by key.

        Args:
            key: The configuration key to retrieve

        Returns:
            The configuration value, or None if not found
        """
        with self._connection() as conn:
            cursor = conn.execute(
                "SELECT value FROM config WHERE key = ?",
                (key,)
            )
            result = cursor.fetchone()
            return result[0] if result else None

    def set_config(self, key: str, value: str) -> None:
        """
        Set a configuration value.

        Args:
            key: The configuration key
            value: The value to store
        """
        with self._connection() as conn:
            conn.execute(
                "INSERT OR REPLACE INTO config (key, value) VALUES (?, ?)",
                (key, value)
            )

    def delete_config(self, key: str) -> bool:
        """
        Delete a configuration value.

        Args:
            key: The configuration key to delete

        Returns:
            True if a row was deleted, False otherwise
        """
        with self._connection() as conn:
            cursor = conn.execute(
                "DELETE FROM config WHERE key = ?",
                (key,)
            )
            return cursor.rowcount > 0

    def get_all_config(self) -> Dict[str, str]:
        """
        Retrieve all configuration values.

        Returns:
            Dictionary of all key-value pairs
        """
        with self._connection() as conn:
            cursor = conn.execute("SELECT key, value FROM config")
            return dict(cursor.fetchall())

    # =========================================================================
    # MCP CONFIGURATION METHODS
    # =========================================================================

    def get_mcp_config(self, key: str) -> Optional[str]:
        """
        Retrieve an MCP configuration value.

        Args:
            key: The MCP configuration key

        Returns:
            The configuration value, or None if not found
        """
        with self._connection() as conn:
            cursor = conn.execute(
                "SELECT value FROM mcp_config WHERE key = ?",
                (key,)
            )
            result = cursor.fetchone()
            return result[0] if result else None

    def set_mcp_config(self, key: str, value: str) -> None:
        """
        Set an MCP configuration value.

        Args:
            key: The MCP configuration key
            value: The value to store
        """
        with self._connection() as conn:
            conn.execute(
                "INSERT OR REPLACE INTO mcp_config (key, value) VALUES (?, ?)",
                (key, value)
            )

    # =========================================================================
    # MCP STATISTICS METHODS
    # =========================================================================

    def log_mcp_stat(
        self,
        tool_name: str,
        folder: Optional[str],
        success: bool,
        error_message: Optional[str] = None
    ) -> None:
        """
        Log an MCP tool usage event.

        Args:
            tool_name: Name of the MCP tool that was called
            folder: The folder path involved (if any)
            success: Whether the call succeeded
            error_message: Error message if the call failed
        """
        timestamp = datetime.datetime.now().isoformat()
        with self._connection() as conn:
            conn.execute(
                """INSERT INTO mcp_stats
                   (timestamp, tool_name, folder, success, error_message)
                   VALUES (?, ?, ?, ?, ?)""",
                (timestamp, tool_name, folder, 1 if success else 0, error_message)
            )

    def get_mcp_stats(self) -> Dict[str, Any]:
        """
        Get aggregated MCP usage statistics.

        Returns:
            Dictionary containing usage statistics:
            - total_calls: Total number of tool calls
            - reads: Number of file reads
            - lists: Number of directory listings
            - searches: Number of file searches
            - db_inspects: Number of database inspections
            - db_searches: Number of database searches
            - db_queries: Number of database queries
            - last_used: Timestamp of last usage
        """
        with self._connection() as conn:
            cursor = conn.execute("""
                SELECT
                    COUNT(*) as total_calls,
                    SUM(CASE WHEN tool_name = 'read_file' THEN 1 ELSE 0 END) as reads,
                    SUM(CASE WHEN tool_name = 'list_directory' THEN 1 ELSE 0 END) as lists,
                    SUM(CASE WHEN tool_name = 'search_files' THEN 1 ELSE 0 END) as searches,
                    SUM(CASE WHEN tool_name = 'inspect_database' THEN 1 ELSE 0 END) as db_inspects,
                    SUM(CASE WHEN tool_name = 'search_database' THEN 1 ELSE 0 END) as db_searches,
                    SUM(CASE WHEN tool_name = 'query_database' THEN 1 ELSE 0 END) as db_queries,
                    MAX(timestamp) as last_used
                FROM mcp_stats
            """)
            row = cursor.fetchone()
            return {
                "total_calls": row[0] or 0,
                "reads": row[1] or 0,
                "lists": row[2] or 0,
                "searches": row[3] or 0,
                "db_inspects": row[4] or 0,
                "db_searches": row[5] or 0,
                "db_queries": row[6] or 0,
                "last_used": row[7],
            }

    # =========================================================================
    # MCP DATABASE REGISTRY METHODS
    # =========================================================================

    def add_mcp_database(self, db_info: Dict[str, Any]) -> int:
        """
        Register a database for MCP access.

        Args:
            db_info: Dictionary containing:
                - path: Database file path
                - name: Display name
                - size: File size in bytes
                - tables: List of table names
                - added: Timestamp when added

        Returns:
            The database ID
        """
        with self._connection() as conn:
            conn.execute(
                """INSERT INTO mcp_databases
                   (path, name, size, tables, added_timestamp)
                   VALUES (?, ?, ?, ?, ?)""",
                (
                    db_info["path"],
                    db_info["name"],
                    db_info["size"],
                    json.dumps(db_info["tables"]),
                    db_info["added"]
                )
            )
            cursor = conn.execute(
                "SELECT id FROM mcp_databases WHERE path = ?",
                (db_info["path"],)
            )
            return cursor.fetchone()[0]

    def remove_mcp_database(self, db_path: str) -> bool:
        """
        Remove a database from the MCP registry.

        Args:
            db_path: Path to the database file

        Returns:
            True if a row was deleted, False otherwise
        """
        with self._connection() as conn:
            cursor = conn.execute(
                "DELETE FROM mcp_databases WHERE path = ?",
                (db_path,)
            )
            return cursor.rowcount > 0

    def get_mcp_databases(self) -> List[Dict[str, Any]]:
        """
        Retrieve all registered MCP databases.

        Returns:
            List of database information dictionaries
        """
        with self._connection() as conn:
            cursor = conn.execute(
                """SELECT id, path, name, size, tables, added_timestamp
                   FROM mcp_databases ORDER BY id"""
            )
            databases = []
            for row in cursor.fetchall():
                tables_list = json.loads(row[4]) if row[4] else []
                databases.append({
                    "id": row[0],
                    "path": row[1],
                    "name": row[2],
                    "size": row[3],
                    "tables": tables_list,
                    "added": row[5],
                })
            return databases

    # =========================================================================
    # CONVERSATION METHODS
    # =========================================================================

    def save_conversation(self, name: str, data: List[Dict[str, str]]) -> None:
        """
        Save a conversation session.

        Args:
            name: Name/identifier for the conversation
            data: List of message dictionaries with 'prompt' and 'response' keys
        """
        timestamp = datetime.datetime.now().isoformat()
        data_json = json.dumps(data)
        with self._connection() as conn:
            conn.execute(
                """INSERT INTO conversation_sessions
                   (name, timestamp, data) VALUES (?, ?, ?)""",
                (name, timestamp, data_json)
            )

    def load_conversation(self, name: str) -> Optional[List[Dict[str, str]]]:
        """
        Load a conversation by name.

        Args:
            name: Name of the conversation to load

        Returns:
            List of messages, or None if not found
        """
        with self._connection() as conn:
            cursor = conn.execute(
                """SELECT data FROM conversation_sessions
                   WHERE name = ?
                   ORDER BY timestamp DESC LIMIT 1""",
                (name,)
            )
            result = cursor.fetchone()
            if result:
                return json.loads(result[0])
            return None

    def delete_conversation(self, name: str) -> int:
        """
        Delete a conversation by name.

        Args:
            name: Name of the conversation to delete

        Returns:
            Number of rows deleted
        """
        with self._connection() as conn:
            cursor = conn.execute(
                "DELETE FROM conversation_sessions WHERE name = ?",
                (name,)
            )
            return cursor.rowcount

    def list_conversations(self) -> List[Dict[str, Any]]:
        """
        List all saved conversations.

        Returns:
            List of conversation summaries with name, timestamp, and message_count
        """
        with self._connection() as conn:
            cursor = conn.execute("""
                SELECT name, MAX(timestamp) as last_saved, data
                FROM conversation_sessions
                GROUP BY name
                ORDER BY last_saved DESC
            """)
            conversations = []
            for row in cursor.fetchall():
                name, timestamp, data_json = row
                data = json.loads(data_json)
                conversations.append({
                    "name": name,
                    "timestamp": timestamp,
                    "message_count": len(data),
                })
            return conversations


# Global database instance for convenience
_db: Optional[Database] = None


def get_database() -> Database:
    """
    Get the global database instance.

    Returns:
        The shared Database instance
    """
    global _db
    if _db is None:
        _db = Database()
    return _db
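The config tables above rely on SQLite's `INSERT OR REPLACE` against a `PRIMARY KEY` column to get upsert semantics: writing the same key twice leaves one row holding the latest value. That behavior can be checked against an in-memory database with nothing but the standard library (this snippet is a standalone check, not part of the oai source):

```python
# Verify the key-value upsert pattern used by Database.set_config:
# INSERT OR REPLACE on a PRIMARY KEY overwrites instead of duplicating.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE IF NOT EXISTS config (key TEXT PRIMARY KEY, value TEXT NOT NULL)"
)
conn.execute(
    "INSERT OR REPLACE INTO config (key, value) VALUES (?, ?)", ("log_level", "info")
)
conn.execute(
    "INSERT OR REPLACE INTO config (key, value) VALUES (?, ?)", ("log_level", "debug")
)
row = conn.execute(
    "SELECT value FROM config WHERE key = ?", ("log_level",)
).fetchone()
print(row[0])  # "debug" -- the second write replaced the first
```

Note that `INSERT OR REPLACE` deletes and re-inserts the row; for these simple two-column tables that is equivalent to an update.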
361  oai/config/settings.py  Normal file
@@ -0,0 +1,361 @@
"""
Settings management for oAI.

This module provides a centralized settings class that handles all application
configuration with type safety, validation, and persistence.
"""

from dataclasses import dataclass, field
from typing import Optional
from pathlib import Path

from oai.constants import (
    DEFAULT_BASE_URL,
    DEFAULT_STREAM_ENABLED,
    DEFAULT_MAX_TOKENS,
    DEFAULT_ONLINE_MODE,
    DEFAULT_COST_WARNING_THRESHOLD,
    DEFAULT_LOG_MAX_SIZE_MB,
    DEFAULT_LOG_BACKUP_COUNT,
    DEFAULT_LOG_LEVEL,
    DEFAULT_SYSTEM_PROMPT,
    VALID_LOG_LEVELS,
)
from oai.config.database import get_database


@dataclass
class Settings:
    """
    Application settings with persistence support.

    This class provides a clean interface for managing all configuration
    options. Settings are automatically loaded from the database on
    initialization and can be persisted back.

    Attributes:
        api_key: OpenRouter API key
        base_url: API base URL
        default_model: Default model ID to use
        default_system_prompt: Custom system prompt (None = use hardcoded default, "" = blank)
        stream_enabled: Whether to stream responses
        max_tokens: Maximum tokens per request
        cost_warning_threshold: Alert threshold for message cost
        default_online_mode: Whether online mode is enabled by default
        log_max_size_mb: Maximum log file size in MB
        log_backup_count: Number of log file backups to keep
        log_level: Logging level (debug/info/warning/error/critical)
    """

    api_key: Optional[str] = None
    base_url: str = DEFAULT_BASE_URL
    default_model: Optional[str] = None
    default_system_prompt: Optional[str] = None
    stream_enabled: bool = DEFAULT_STREAM_ENABLED
    max_tokens: int = DEFAULT_MAX_TOKENS
    cost_warning_threshold: float = DEFAULT_COST_WARNING_THRESHOLD
    default_online_mode: bool = DEFAULT_ONLINE_MODE
    log_max_size_mb: int = DEFAULT_LOG_MAX_SIZE_MB
    log_backup_count: int = DEFAULT_LOG_BACKUP_COUNT
    log_level: str = DEFAULT_LOG_LEVEL

    @property
    def effective_system_prompt(self) -> str:
        """
        Get the effective system prompt to use.

        Returns:
            The custom prompt if set, hardcoded default if None, or blank if explicitly set to ""
        """
        if self.default_system_prompt is None:
            return DEFAULT_SYSTEM_PROMPT
        return self.default_system_prompt

    def __post_init__(self):
        """Validate settings after initialization."""
        self._validate()

    def _validate(self) -> None:
        """Validate all settings values."""
        # Validate log level
        if self.log_level.lower() not in VALID_LOG_LEVELS:
            raise ValueError(
                f"Invalid log level: {self.log_level}. "
                f"Must be one of: {', '.join(VALID_LOG_LEVELS.keys())}"
            )

        # Validate numeric bounds
        if self.max_tokens < 1:
            raise ValueError("max_tokens must be at least 1")

        if self.cost_warning_threshold < 0:
            raise ValueError("cost_warning_threshold must be non-negative")

        if self.log_max_size_mb < 1:
            raise ValueError("log_max_size_mb must be at least 1")

        if self.log_backup_count < 0:
            raise ValueError("log_backup_count must be non-negative")

    @classmethod
    def load(cls) -> "Settings":
        """
        Load settings from the database.

        Returns:
            Settings instance with values from database
        """
        db = get_database()

        # Helper to safely parse boolean
        def parse_bool(value: Optional[str], default: bool) -> bool:
            if value is None:
                return default
            return value.lower() in ("on", "true", "1", "yes")

        # Helper to safely parse int
        def parse_int(value: Optional[str], default: int) -> int:
            if value is None:
                return default
            try:
                return int(value)
            except ValueError:
                return default

        # Helper to safely parse float
        def parse_float(value: Optional[str], default: float) -> float:
            if value is None:
                return default
            try:
                return float(value)
            except ValueError:
                return default

        # Get system prompt from DB: None means not set (use default), "" means explicitly blank
        system_prompt_value = db.get_config("default_system_prompt")

        return cls(
            api_key=db.get_config("api_key"),
            base_url=db.get_config("base_url") or DEFAULT_BASE_URL,
            default_model=db.get_config("default_model"),
            default_system_prompt=system_prompt_value,
            stream_enabled=parse_bool(
                db.get_config("stream_enabled"),
                DEFAULT_STREAM_ENABLED
            ),
            max_tokens=parse_int(
                db.get_config("max_token"),
                DEFAULT_MAX_TOKENS
            ),
            cost_warning_threshold=parse_float(
                db.get_config("cost_warning_threshold"),
                DEFAULT_COST_WARNING_THRESHOLD
            ),
            default_online_mode=parse_bool(
                db.get_config("default_online_mode"),
                DEFAULT_ONLINE_MODE
            ),
            log_max_size_mb=parse_int(
                db.get_config("log_max_size_mb"),
                DEFAULT_LOG_MAX_SIZE_MB
            ),
            log_backup_count=parse_int(
                db.get_config("log_backup_count"),
                DEFAULT_LOG_BACKUP_COUNT
            ),
            log_level=db.get_config("log_level") or DEFAULT_LOG_LEVEL,
        )
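The lenient parsing helpers inside `load()` are what keep a corrupted or missing stored value from crashing startup: anything unparseable silently falls back to the default. Their behavior can be checked standalone (the functions below are copies of the nested helpers, lifted out for illustration; nothing here imports oai):

```python
# Lenient config parsers: missing (None) or malformed stored strings
# fall back to the supplied default instead of raising.
from typing import Optional


def parse_bool(value: Optional[str], default: bool) -> bool:
    if value is None:
        return default
    return value.lower() in ("on", "true", "1", "yes")


def parse_int(value: Optional[str], default: int) -> int:
    if value is None:
        return default
    try:
        return int(value)
    except ValueError:
        return default


print(parse_bool("ON", False))     # True  (case-insensitive match)
print(parse_bool(None, True))      # True  (missing -> default)
print(parse_int("garbage", 4096))  # 4096  (unparseable -> default)
```

One consequence of this design: any unrecognized boolean string (e.g. "enabled") is treated as False rather than as the default, since only None triggers the fallback.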

    def save(self) -> None:
        """Persist all settings to the database."""
        db = get_database()

        # Only save API key if it exists
        if self.api_key:
            db.set_config("api_key", self.api_key)

        db.set_config("base_url", self.base_url)

        if self.default_model:
            db.set_config("default_model", self.default_model)

        # Save system prompt: None means not set (don't save), otherwise save the value (even if "")
        if self.default_system_prompt is not None:
            db.set_config("default_system_prompt", self.default_system_prompt)

        db.set_config("stream_enabled", "on" if self.stream_enabled else "off")
        db.set_config("max_token", str(self.max_tokens))
        db.set_config("cost_warning_threshold", str(self.cost_warning_threshold))
        db.set_config("default_online_mode", "on" if self.default_online_mode else "off")
        db.set_config("log_max_size_mb", str(self.log_max_size_mb))
        db.set_config("log_backup_count", str(self.log_backup_count))
        db.set_config("log_level", self.log_level)

    def set_api_key(self, api_key: str) -> None:
        """
        Set and persist the API key.

        Args:
            api_key: The new API key
        """
        self.api_key = api_key.strip()
        get_database().set_config("api_key", self.api_key)

    def set_base_url(self, url: str) -> None:
        """
        Set and persist the base URL.

        Args:
            url: The new base URL
        """
        self.base_url = url.strip()
        get_database().set_config("base_url", self.base_url)

    def set_default_model(self, model_id: str) -> None:
        """
        Set and persist the default model.

        Args:
            model_id: The model ID to set as default
        """
        self.default_model = model_id
        get_database().set_config("default_model", model_id)

    def set_default_system_prompt(self, prompt: str) -> None:
        """
        Set and persist the default system prompt.

        Args:
            prompt: The system prompt to use for all new sessions.
                    Empty string "" means blank prompt (no system message).
        """
        self.default_system_prompt = prompt
        get_database().set_config("default_system_prompt", prompt)

    def clear_default_system_prompt(self) -> None:
        """
|
Clear the custom system prompt and revert to hardcoded default.
|
||||||
|
|
||||||
|
This removes the custom prompt from the database, causing the
|
||||||
|
application to use the built-in DEFAULT_SYSTEM_PROMPT.
|
||||||
|
"""
|
||||||
|
self.default_system_prompt = None
|
||||||
|
# Remove from database to indicate "not set"
|
||||||
|
db = get_database()
|
||||||
|
with db._connection() as conn:
|
||||||
|
conn.execute("DELETE FROM config WHERE key = ?", ("default_system_prompt",))
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def set_stream_enabled(self, enabled: bool) -> None:
|
||||||
|
"""
|
||||||
|
Set and persist the streaming preference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
enabled: Whether to enable streaming
|
||||||
|
"""
|
||||||
|
self.stream_enabled = enabled
|
||||||
|
get_database().set_config("stream_enabled", "on" if enabled else "off")
|
||||||
|
|
||||||
|
def set_max_tokens(self, max_tokens: int) -> None:
|
||||||
|
"""
|
||||||
|
Set and persist the maximum tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_tokens: Maximum number of tokens
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If max_tokens is less than 1
|
||||||
|
"""
|
||||||
|
if max_tokens < 1:
|
||||||
|
raise ValueError("max_tokens must be at least 1")
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
get_database().set_config("max_token", str(max_tokens))
|
||||||
|
|
||||||
|
def set_cost_warning_threshold(self, threshold: float) -> None:
|
||||||
|
"""
|
||||||
|
Set and persist the cost warning threshold.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
threshold: Cost threshold in USD
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If threshold is negative
|
||||||
|
"""
|
||||||
|
if threshold < 0:
|
||||||
|
raise ValueError("cost_warning_threshold must be non-negative")
|
||||||
|
self.cost_warning_threshold = threshold
|
||||||
|
get_database().set_config("cost_warning_threshold", str(threshold))
|
||||||
|
|
||||||
|
def set_default_online_mode(self, enabled: bool) -> None:
|
||||||
|
"""
|
||||||
|
Set and persist the default online mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
enabled: Whether online mode should be enabled by default
|
||||||
|
"""
|
||||||
|
self.default_online_mode = enabled
|
||||||
|
get_database().set_config("default_online_mode", "on" if enabled else "off")
|
||||||
|
|
||||||
|
def set_log_level(self, level: str) -> None:
|
||||||
|
"""
|
||||||
|
Set and persist the log level.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: The log level (debug/info/warning/error/critical)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If level is not valid
|
||||||
|
"""
|
||||||
|
level_lower = level.lower()
|
||||||
|
if level_lower not in VALID_LOG_LEVELS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid log level: {level}. "
|
||||||
|
f"Must be one of: {', '.join(VALID_LOG_LEVELS.keys())}"
|
||||||
|
)
|
||||||
|
self.log_level = level_lower
|
||||||
|
get_database().set_config("log_level", level_lower)
|
||||||
|
|
||||||
|
def set_log_max_size(self, size_mb: int) -> None:
|
||||||
|
"""
|
||||||
|
Set and persist the maximum log file size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
size_mb: Maximum size in megabytes
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If size_mb is less than 1
|
||||||
|
"""
|
||||||
|
if size_mb < 1:
|
||||||
|
raise ValueError("log_max_size_mb must be at least 1")
|
||||||
|
# Cap at 100 MB for safety
|
||||||
|
self.log_max_size_mb = min(size_mb, 100)
|
||||||
|
get_database().set_config("log_max_size_mb", str(self.log_max_size_mb))
|
||||||
|
|
||||||
|
|
||||||
|
# Global settings instance
|
||||||
|
_settings: Optional[Settings] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
"""
|
||||||
|
Get the global settings instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The shared Settings instance, loading from database if needed
|
||||||
|
"""
|
||||||
|
global _settings
|
||||||
|
if _settings is None:
|
||||||
|
_settings = Settings.load()
|
||||||
|
return _settings
|
||||||
|
|
||||||
|
|
||||||
|
def reload_settings() -> Settings:
|
||||||
|
"""
|
||||||
|
Force reload settings from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Fresh Settings instance
|
||||||
|
"""
|
||||||
|
global _settings
|
||||||
|
_settings = Settings.load()
|
||||||
|
return _settings
|
||||||
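The tri-state `default_system_prompt` handling above (`None` = not set, `""` = explicitly blank, anything else = custom) determines which system prompt a session actually sends. A minimal standalone sketch of that resolution logic — `effective_system_prompt` and the inlined default value are illustrative, not part of oAI's API:

```python
from typing import Optional

# Stand-in for the constant defined in oai/constants.py
DEFAULT_SYSTEM_PROMPT = "You are a knowledgeable and helpful AI assistant."


def effective_system_prompt(stored: Optional[str]) -> Optional[str]:
    """Resolve the stored tri-state value to the prompt actually sent.

    None  -> not set: fall back to the hardcoded default
    ""    -> explicitly blank: send no system message at all
    other -> use the stored custom prompt
    """
    if stored is None:
        return DEFAULT_SYSTEM_PROMPT
    if stored == "":
        return None  # no system message
    return stored
```

This is why `clear_default_system_prompt` deletes the row entirely instead of writing an empty string: only a missing key means "use the built-in default".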
451	oai/constants.py	Normal file
@@ -0,0 +1,451 @@
"""
Application-wide constants for oAI.

This module contains all configuration constants, default values, and static
definitions used throughout the application. Centralizing these values makes
the codebase easier to maintain and configure.
"""

from pathlib import Path
from typing import Set, Dict, Any
import logging

# Import version from single source of truth
from oai import __version__

# =============================================================================
# APPLICATION METADATA
# =============================================================================

APP_NAME = "oAI"
APP_VERSION = __version__  # Single source of truth in oai/__init__.py
APP_URL = "https://iurl.no/oai"
APP_DESCRIPTION = "OpenRouter AI Chat Client with MCP Integration"

# =============================================================================
# FILE PATHS
# =============================================================================

HOME_DIR = Path.home()
CONFIG_DIR = HOME_DIR / ".config" / "oai"
CACHE_DIR = HOME_DIR / ".cache" / "oai"
HISTORY_FILE = CONFIG_DIR / "history.txt"
DATABASE_FILE = CONFIG_DIR / "oai_config.db"
LOG_FILE = CONFIG_DIR / "oai.log"

# =============================================================================
# API CONFIGURATION
# =============================================================================

DEFAULT_BASE_URL = "https://openrouter.ai/api/v1"
DEFAULT_STREAM_ENABLED = True
DEFAULT_MAX_TOKENS = 100_000
DEFAULT_ONLINE_MODE = False

# =============================================================================
# DEFAULT SYSTEM PROMPT
# =============================================================================

DEFAULT_SYSTEM_PROMPT = (
    "You are a knowledgeable and helpful AI assistant. Provide clear, accurate, "
    "and well-structured responses. Be concise yet thorough. When uncertain about "
    "something, acknowledge your limitations. For technical topics, include relevant "
    "details and examples when helpful."
)

# =============================================================================
# PRICING DEFAULTS (per million tokens)
# =============================================================================

DEFAULT_INPUT_PRICE = 3.0
DEFAULT_OUTPUT_PRICE = 15.0

MODEL_PRICING: Dict[str, float] = {
    "input": DEFAULT_INPUT_PRICE,
    "output": DEFAULT_OUTPUT_PRICE,
}

# =============================================================================
# CREDIT ALERTS
# =============================================================================

LOW_CREDIT_RATIO = 0.1  # Alert when credits < 10% of total
LOW_CREDIT_AMOUNT = 1.0  # Alert when credits < $1.00
DEFAULT_COST_WARNING_THRESHOLD = 0.01  # Alert when single message cost exceeds this
COST_WARNING_THRESHOLD = DEFAULT_COST_WARNING_THRESHOLD  # Alias for convenience

# =============================================================================
# LOGGING CONFIGURATION
# =============================================================================

DEFAULT_LOG_MAX_SIZE_MB = 10
DEFAULT_LOG_BACKUP_COUNT = 2
DEFAULT_LOG_LEVEL = "info"

VALID_LOG_LEVELS: Dict[str, int] = {
    "debug": logging.DEBUG,
    "info": logging.INFO,
    "warning": logging.WARNING,
    "error": logging.ERROR,
    "critical": logging.CRITICAL,
}

# =============================================================================
# FILE HANDLING
# =============================================================================

# Maximum file size for reading (10 MB)
MAX_FILE_SIZE = 10 * 1024 * 1024

# Content truncation threshold (50 KB)
CONTENT_TRUNCATION_THRESHOLD = 50 * 1024

# Maximum items in directory listing
MAX_LIST_ITEMS = 1000

# Supported code file extensions for syntax highlighting
SUPPORTED_CODE_EXTENSIONS: Set[str] = {
    ".py", ".js", ".ts", ".cs", ".java", ".c", ".cpp", ".h", ".hpp",
    ".rb", ".ruby", ".php", ".swift", ".kt", ".kts", ".go",
    ".sh", ".bat", ".ps1", ".r", ".scala", ".pl", ".lua", ".dart",
    ".elm", ".xml", ".json", ".yaml", ".yml", ".md", ".txt",
}

# All allowed file extensions for attachment
ALLOWED_FILE_EXTENSIONS: Set[str] = {
    # Code files
    ".py", ".js", ".ts", ".jsx", ".tsx", ".vue", ".java", ".c", ".cpp", ".cc", ".cxx",
    ".h", ".hpp", ".hxx", ".rb", ".go", ".rs", ".swift", ".kt", ".kts", ".php",
    ".sh", ".bash", ".zsh", ".fish", ".bat", ".cmd", ".ps1",
    # Data files
    ".json", ".csv", ".yaml", ".yml", ".toml", ".xml", ".sql", ".db", ".sqlite", ".sqlite3",
    # Documents
    ".txt", ".md", ".log", ".conf", ".cfg", ".ini", ".env", ".properties",
    # Images
    ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".svg", ".ico",
    # Archives
    ".zip", ".tar", ".gz", ".bz2", ".7z", ".rar", ".xz",
    # Config files
    ".lock", ".gitignore", ".dockerignore", ".editorconfig", ".eslintrc",
    ".prettierrc", ".babelrc", ".nvmrc", ".npmrc",
    # Binary/Compiled
    ".pyc", ".pyo", ".pyd", ".so", ".dll", ".dylib", ".exe", ".app",
    ".dmg", ".pkg", ".deb", ".rpm", ".apk", ".ipa",
    # ML/AI
    ".pkl", ".pickle", ".joblib", ".npy", ".npz", ".safetensors", ".onnx",
    ".pt", ".pth", ".ckpt", ".pb", ".tflite", ".mlmodel", ".coreml", ".rknn",
    # Data formats
    ".wasm", ".proto", ".graphql", ".graphqls", ".grpc", ".avro", ".parquet",
    ".orc", ".feather", ".arrow", ".hdf5", ".h5", ".mat", ".rdata", ".rds",
    # Other
    ".pdf", ".class", ".jar", ".war",
}

# =============================================================================
# SECURITY CONFIGURATION
# =============================================================================

# System directories that should never be accessed
SYSTEM_DIRS_BLACKLIST: Set[str] = {
    # macOS
    "/System", "/Library", "/private", "/usr", "/bin", "/sbin",
    # Linux
    "/boot", "/dev", "/proc", "/sys", "/root",
    # Windows
    "C:\\Windows", "C:\\Program Files", "C:\\Program Files (x86)",
}

# Directories to skip during file operations
SKIP_DIRECTORIES: Set[str] = {
    # Python virtual environments
    ".venv", "venv", "env", "virtualenv",
    "site-packages", "dist-packages",
    # Python caches
    "__pycache__", ".pytest_cache", ".mypy_cache",
    # JavaScript/Node
    "node_modules",
    # Version control
    ".git", ".svn",
    # IDEs
    ".idea", ".vscode",
    # Build directories
    "build", "dist", "eggs", ".eggs",
}

# =============================================================================
# DATABASE QUERIES - SQL SAFETY
# =============================================================================

# Maximum query execution timeout (seconds)
MAX_QUERY_TIMEOUT = 5

# Maximum rows returned from queries
MAX_QUERY_RESULTS = 1000

# Default rows per query
DEFAULT_QUERY_LIMIT = 100

# Keywords that are blocked in database queries
DANGEROUS_SQL_KEYWORDS: Set[str] = {
    "INSERT", "UPDATE", "DELETE", "DROP", "CREATE",
    "ALTER", "TRUNCATE", "REPLACE", "ATTACH", "DETACH",
    "PRAGMA", "VACUUM", "REINDEX",
}

# =============================================================================
# MCP CONFIGURATION
# =============================================================================

# Maximum tool call iterations per request
MAX_TOOL_LOOPS = 5

# =============================================================================
# VALID COMMANDS
# =============================================================================

VALID_COMMANDS: Set[str] = {
    "/retry", "/online", "/memory", "/paste", "/export", "/save", "/load",
    "/delete", "/list", "/prev", "/next", "/stats", "/middleout", "/reset",
    "/info", "/model", "/maxtoken", "/system", "/config", "/credits", "/clear",
    "/cl", "/help", "/mcp",
}

# =============================================================================
# COMMAND HELP DATABASE
# =============================================================================

COMMAND_HELP: Dict[str, Dict[str, Any]] = {
    "/clear": {
        "aliases": ["/cl"],
        "description": "Clear the terminal screen for a clean interface.",
        "usage": "/clear\n/cl",
        "examples": [
            ("Clear screen", "/clear"),
            ("Using short alias", "/cl"),
        ],
        "notes": "You can also use the keyboard shortcut Ctrl+L.",
    },
    "/help": {
        "description": "Display help information for commands.",
        "usage": "/help [command|topic]",
        "examples": [
            ("Show all commands", "/help"),
            ("Get help for a specific command", "/help /model"),
            ("Get detailed MCP help", "/help mcp"),
        ],
        "notes": "Use /help without arguments to see the full command list.",
    },
    "mcp": {
        "description": "Complete guide to MCP (Model Context Protocol).",
        "usage": "See detailed examples below",
        "examples": [],
        "notes": """
MCP (Model Context Protocol) gives your AI assistant direct access to:
  • Local files and folders (read, search, list)
  • SQLite databases (inspect, search, query)

FILE MODE (default):
  /mcp on                     Start MCP server
  /mcp add ~/Documents        Grant access to folder
  /mcp list                   View all allowed folders

DATABASE MODE:
  /mcp add db ~/app/data.db   Add specific database
  /mcp db list                View all databases
  /mcp db 1                   Work with database #1
  /mcp files                  Switch back to file mode

WRITE MODE (optional):
  /mcp write on               Enable file modifications
  /mcp write off              Disable write mode (back to read-only)

For command-specific help: /help /mcp
""",
    },
    "/mcp": {
        "description": "Manage MCP for local file access and SQLite database querying.",
        "usage": "/mcp <command> [args]",
        "examples": [
            ("Enable MCP server", "/mcp on"),
            ("Disable MCP server", "/mcp off"),
            ("Show MCP status", "/mcp status"),
            ("", ""),
            ("━━━ FILE MODE ━━━", ""),
            ("Add folder for file access", "/mcp add ~/Documents"),
            ("Remove folder", "/mcp remove ~/Desktop"),
            ("List allowed folders", "/mcp list"),
            ("Enable write mode", "/mcp write on"),
            ("", ""),
            ("━━━ DATABASE MODE ━━━", ""),
            ("Add SQLite database", "/mcp add db ~/app/data.db"),
            ("List all databases", "/mcp db list"),
            ("Switch to database #1", "/mcp db 1"),
            ("Switch back to file mode", "/mcp files"),
        ],
        "notes": "MCP allows AI to read local files and query SQLite databases.",
    },
    "/memory": {
        "description": "Toggle conversation memory.",
        "usage": "/memory [on|off]",
        "examples": [
            ("Check current memory status", "/memory"),
            ("Enable conversation memory", "/memory on"),
            ("Disable memory (save costs)", "/memory off"),
        ],
        "notes": "Memory is ON by default. Disabling saves tokens.",
    },
    "/online": {
        "description": "Enable or disable online mode (web search).",
        "usage": "/online [on|off]",
        "examples": [
            ("Check online mode status", "/online"),
            ("Enable web search", "/online on"),
            ("Disable web search", "/online off"),
        ],
        "notes": "Not all models support online mode.",
    },
    "/paste": {
        "description": "Paste plain text from clipboard and send to the AI.",
        "usage": "/paste [prompt]",
        "examples": [
            ("Paste clipboard content", "/paste"),
            ("Paste with a question", "/paste Explain this code"),
        ],
        "notes": "Only plain text is supported.",
    },
    "/retry": {
        "description": "Resend the last prompt from conversation history.",
        "usage": "/retry",
        "examples": [("Retry last message", "/retry")],
        "notes": "Requires at least one message in history.",
    },
    "/next": {
        "description": "View the next response in conversation history.",
        "usage": "/next",
        "examples": [("Navigate to next response", "/next")],
        "notes": "Use /prev to go backward.",
    },
    "/prev": {
        "description": "View the previous response in conversation history.",
        "usage": "/prev",
        "examples": [("Navigate to previous response", "/prev")],
        "notes": "Use /next to go forward.",
    },
    "/reset": {
        "description": "Clear conversation history and reset system prompt.",
        "usage": "/reset",
        "examples": [("Reset conversation", "/reset")],
        "notes": "Requires confirmation.",
    },
    "/info": {
        "description": "Display detailed information about a model.",
        "usage": "/info [model_id]",
        "examples": [
            ("Show current model info", "/info"),
            ("Show specific model info", "/info gpt-4o"),
        ],
        "notes": "Shows pricing, capabilities, and context length.",
    },
    "/model": {
        "description": "Select or change the AI model.",
        "usage": "/model [search_term]",
        "examples": [
            ("List all models", "/model"),
            ("Search for GPT models", "/model gpt"),
            ("Search for Claude models", "/model claude"),
        ],
        "notes": "Models are numbered for easy selection.",
    },
    "/config": {
        "description": "View or modify application configuration.",
        "usage": "/config [setting] [value]",
        "examples": [
            ("View all settings", "/config"),
            ("Set API key", "/config api"),
            ("Set default model", "/config model"),
            ("Set system prompt", "/config system You are a helpful assistant"),
            ("Enable streaming", "/config stream on"),
        ],
        "notes": "Available: api, url, model, system, stream, costwarning, maxtoken, online, loglevel.",
    },
    "/maxtoken": {
        "description": "Set a temporary session token limit.",
        "usage": "/maxtoken [value]",
        "examples": [
            ("View current session limit", "/maxtoken"),
            ("Set session limit to 2000", "/maxtoken 2000"),
        ],
        "notes": "Cannot exceed stored max token limit.",
    },
    "/system": {
        "description": "Set or clear the session-level system prompt.",
        "usage": "/system [prompt|clear|default <prompt>]",
        "examples": [
            ("View current system prompt", "/system"),
            ("Set as Python expert", "/system You are a Python expert"),
            ("Multiline with newlines", r"/system You are an expert.\nBe clear and concise."),
            ("Save as default", "/system default You are a helpful assistant"),
            ("Revert to default", "/system clear"),
            ("Blank prompt", '/system ""'),
        ],
        "notes": r"Use \n for newlines. /system clear reverts to hardcoded default.",
    },
    "/save": {
        "description": "Save the current conversation history.",
        "usage": "/save <name>",
        "examples": [("Save conversation", "/save my_chat")],
        "notes": "Saved conversations can be loaded later with /load.",
    },
    "/load": {
        "description": "Load a saved conversation.",
        "usage": "/load <name|number>",
        "examples": [
            ("Load by name", "/load my_chat"),
            ("Load by number from /list", "/load 3"),
        ],
        "notes": "Use /list to see numbered conversations.",
    },
    "/delete": {
        "description": "Delete a saved conversation.",
        "usage": "/delete <name|number>",
        "examples": [("Delete by name", "/delete my_chat")],
        "notes": "Requires confirmation. Cannot be undone.",
    },
    "/list": {
        "description": "List all saved conversations.",
        "usage": "/list",
        "examples": [("Show saved conversations", "/list")],
        "notes": "Conversations are numbered for use with /load and /delete.",
    },
    "/export": {
        "description": "Export the current conversation to a file.",
        "usage": "/export <format> <filename>",
        "examples": [
            ("Export as Markdown", "/export md notes.md"),
            ("Export as JSON", "/export json conversation.json"),
            ("Export as HTML", "/export html report.html"),
        ],
        "notes": "Available formats: md, json, html.",
    },
    "/stats": {
        "description": "Display session statistics.",
        "usage": "/stats",
        "examples": [("View session statistics", "/stats")],
        "notes": "Shows tokens, costs, and credits.",
    },
    "/credits": {
        "description": "Display your OpenRouter account credits.",
        "usage": "/credits",
        "examples": [("Check credits", "/credits")],
        "notes": "Shows total, used, and remaining credits.",
    },
    "/middleout": {
        "description": "Toggle middle-out transform for long prompts.",
        "usage": "/middleout [on|off]",
        "examples": [
            ("Check status", "/middleout"),
            ("Enable compression", "/middleout on"),
        ],
        "notes": "Compresses prompts exceeding context size.",
    },
}
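The `DANGEROUS_SQL_KEYWORDS` blocklist above implies a read-only guard applied before a user query reaches SQLite. A minimal sketch of how such a check could work — `is_query_allowed` is a hypothetical helper for illustration, not part of oAI's API:

```python
import re
from typing import Set

# Mirrors DANGEROUS_SQL_KEYWORDS from oai/constants.py
DANGEROUS_SQL_KEYWORDS: Set[str] = {
    "INSERT", "UPDATE", "DELETE", "DROP", "CREATE",
    "ALTER", "TRUNCATE", "REPLACE", "ATTACH", "DETACH",
    "PRAGMA", "VACUUM", "REINDEX",
}


def is_query_allowed(sql: str) -> bool:
    """Reject a query if any blocked keyword appears as a whole token.

    Tokenizing on identifier characters avoids false positives from
    column names such as "created_at" matching CREATE.
    """
    tokens = {t.upper() for t in re.findall(r"[A-Za-z_]+", sql)}
    return not (tokens & DANGEROUS_SQL_KEYWORDS)
```

A real guard would likely pair this with the `MAX_QUERY_TIMEOUT` and `MAX_QUERY_RESULTS` limits above and a read-only database connection, since keyword filtering alone is not a complete defense.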
14	oai/core/__init__.py	Normal file
@@ -0,0 +1,14 @@
"""
Core functionality for oAI.

This module provides the main session management and AI client
classes that power the chat application.
"""

from oai.core.session import ChatSession
from oai.core.client import AIClient

__all__ = [
    "ChatSession",
    "AIClient",
]
422	oai/core/client.py	Normal file
@@ -0,0 +1,422 @@
"""
AI Client for oAI.

This module provides a high-level client for interacting with AI models
through the provider abstraction layer.
"""

import asyncio
import json
from typing import Any, Callable, Dict, Iterator, List, Optional, Union

from oai.constants import APP_NAME, APP_URL, MODEL_PRICING
from oai.providers.base import (
    AIProvider,
    ChatMessage,
    ChatResponse,
    ModelInfo,
    StreamChunk,
    ToolCall,
    UsageStats,
)
from oai.providers.openrouter import OpenRouterProvider
from oai.utils.logging import get_logger


class AIClient:
    """
    High-level AI client for chat interactions.

    Provides a simplified interface for sending chat requests,
    handling streaming, and managing tool calls.

    Attributes:
        provider: The underlying AI provider
        default_model: Default model ID to use
        http_headers: Custom HTTP headers for requests
    """

    def __init__(
        self,
        api_key: str,
        base_url: Optional[str] = None,
        provider_class: type = OpenRouterProvider,
        app_name: str = APP_NAME,
        app_url: str = APP_URL,
    ):
        """
        Initialize the AI client.

        Args:
            api_key: API key for authentication
            base_url: Optional custom base URL
            provider_class: Provider class to use (default: OpenRouterProvider)
            app_name: Application name for headers
            app_url: Application URL for headers
        """
        self.provider: AIProvider = provider_class(
            api_key=api_key,
            base_url=base_url,
            app_name=app_name,
            app_url=app_url,
        )
        self.default_model: Optional[str] = None
        self.logger = get_logger()

    def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
        """
        Get available models.

        Args:
            filter_text_only: Whether to exclude video-only models

        Returns:
            List of ModelInfo objects
        """
        return self.provider.list_models(filter_text_only=filter_text_only)

    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        """
        Get information about a specific model.

        Args:
            model_id: Model identifier

        Returns:
            ModelInfo or None if not found
        """
        return self.provider.get_model(model_id)

    def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get raw model data for provider-specific fields.

        Args:
            model_id: Model identifier

        Returns:
            Raw model dictionary or None
        """
        if hasattr(self.provider, "get_raw_model"):
            return self.provider.get_raw_model(model_id)
        return None

    def chat(
        self,
        messages: List[Dict[str, Any]],
        model: Optional[str] = None,
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        system_prompt: Optional[str] = None,
        online: bool = False,
        transforms: Optional[List[str]] = None,
    ) -> Union[ChatResponse, Iterator[StreamChunk]]:
        """
        Send a chat request.

        Args:
            messages: List of message dictionaries
            model: Model ID (uses default if not specified)
            stream: Whether to stream the response
            max_tokens: Maximum tokens in response
            temperature: Sampling temperature
            tools: Tool definitions for function calling
            tool_choice: Tool selection mode
            system_prompt: System prompt to prepend
            online: Whether to enable online mode
            transforms: List of transforms (e.g., ["middle-out"])

        Returns:
            ChatResponse for non-streaming, Iterator[StreamChunk] for streaming

        Raises:
            ValueError: If no model specified and no default set
        """
        model_id = model or self.default_model
        if not model_id:
            raise ValueError("No model specified and no default set")

        # Apply online mode suffix
        if online and hasattr(self.provider, "get_effective_model_id"):
            model_id = self.provider.get_effective_model_id(model_id, True)

        # Convert dict messages to ChatMessage objects
        chat_messages = []

        # Add system prompt if provided
        if system_prompt:
            chat_messages.append(ChatMessage(role="system", content=system_prompt))

        # Convert message dicts
        for msg in messages:
            # Convert tool_calls dicts to ToolCall objects if present
            tool_calls_data = msg.get("tool_calls")
            tool_calls_obj = None
            if tool_calls_data:
                from oai.providers.base import ToolCall, ToolFunction
                tool_calls_obj = []
                for tc in tool_calls_data:
                    # Handle both ToolCall objects and dicts
                    if isinstance(tc, ToolCall):
                        tool_calls_obj.append(tc)
                    elif isinstance(tc, dict):
                        func_data = tc.get("function", {})
                        tool_calls_obj.append(
                            ToolCall(
                                id=tc.get("id", ""),
                                type=tc.get("type", "function"),
                                function=ToolFunction(
                                    name=func_data.get("name", ""),
                                    arguments=func_data.get("arguments", "{}"),
                                ),
                            )
                        )

            chat_messages.append(
                ChatMessage(
                    role=msg.get("role", "user"),
                    content=msg.get("content"),
                    tool_calls=tool_calls_obj,
                    tool_call_id=msg.get("tool_call_id"),
                )
            )

        self.logger.debug(
            f"Sending chat request: model={model_id}, "
            f"messages={len(chat_messages)}, stream={stream}"
        )

        return self.provider.chat(
            model=model_id,
            messages=chat_messages,
            stream=stream,
            max_tokens=max_tokens,
            temperature=temperature,
            tools=tools,
            tool_choice=tool_choice,
            transforms=transforms,
        )
|
||||||
|
    def chat_with_tools(
        self,
        messages: List[Dict[str, Any]],
        tools: List[Dict[str, Any]],
        tool_executor: Callable[[str, Dict[str, Any]], Dict[str, Any]],
        model: Optional[str] = None,
        max_loops: int = 5,
        max_tokens: Optional[int] = None,
        system_prompt: Optional[str] = None,
        on_tool_call: Optional[Callable[[ToolCall], None]] = None,
        on_tool_result: Optional[Callable[[str, Dict[str, Any]], None]] = None,
    ) -> ChatResponse:
        """
        Send a chat request with automatic tool call handling.

        Executes tool calls returned by the model and continues
        the conversation until no more tool calls are requested.

        Args:
            messages: Initial messages
            tools: Tool definitions
            tool_executor: Function to execute tool calls
            model: Model ID
            max_loops: Maximum tool call iterations
            max_tokens: Maximum response tokens
            system_prompt: System prompt
            on_tool_call: Callback when tool is called
            on_tool_result: Callback when tool returns result

        Returns:
            Final ChatResponse after all tool calls complete
        """
        model_id = model or self.default_model
        if not model_id:
            raise ValueError("No model specified and no default set")

        # Build initial messages
        chat_messages = []
        if system_prompt:
            chat_messages.append({"role": "system", "content": system_prompt})
        chat_messages.extend(messages)

        loop_count = 0
        current_response: Optional[ChatResponse] = None

        while loop_count < max_loops:
            # Send request
            response = self.chat(
                messages=chat_messages,
                model=model_id,
                stream=False,
                max_tokens=max_tokens,
                tools=tools,
                tool_choice="auto",
            )

            if not isinstance(response, ChatResponse):
                raise ValueError("Expected non-streaming response")

            current_response = response

            # Check for tool calls
            tool_calls = response.tool_calls
            if not tool_calls:
                break

            self.logger.info(f"Model requested {len(tool_calls)} tool call(s)")

            # Process each tool call
            tool_results = []
            for tc in tool_calls:
                if on_tool_call:
                    on_tool_call(tc)

                try:
                    args = json.loads(tc.function.arguments)
                except json.JSONDecodeError as e:
                    self.logger.error(f"Failed to parse tool arguments: {e}")
                    result = {"error": f"Invalid arguments: {e}"}
                else:
                    result = tool_executor(tc.function.name, args)

                if on_tool_result:
                    on_tool_result(tc.function.name, result)

                tool_results.append({
                    "tool_call_id": tc.id,
                    "role": "tool",
                    "name": tc.function.name,
                    "content": json.dumps(result),
                })

            # Add assistant message with tool calls
            assistant_msg = {
                "role": "assistant",
                "content": response.content,
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": tc.type,
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments,
                        },
                    }
                    for tc in tool_calls
                ],
            }
            chat_messages.append(assistant_msg)
            chat_messages.extend(tool_results)

            loop_count += 1

        if loop_count >= max_loops:
            self.logger.warning(f"Reached max tool call loops ({max_loops})")

        return current_response
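The `tool_executor` callable that `chat_with_tools` expects maps a tool name and parsed arguments to a JSON-serializable result dict, using the same `{"error": ...}` convention the loop above checks for. A minimal standalone sketch (the `add` and `echo` tools here are hypothetical, for illustration only):

```python
from typing import Any, Dict

def tool_executor(name: str, args: Dict[str, Any]) -> Dict[str, Any]:
    """Dispatch a tool call to a local handler; return a JSON-serializable dict."""
    handlers = {
        # Hypothetical tools for illustration only.
        "add": lambda a: {"result": a["x"] + a["y"]},
        "echo": lambda a: {"result": a.get("text", "")},
    }
    handler = handlers.get(name)
    if handler is None:
        return {"error": f"Unknown tool: {name}"}
    try:
        return handler(args)
    except Exception as e:  # mirror the {"error": ...} convention used by the loop
        return {"error": str(e)}
```

Whatever the executor returns is serialized with `json.dumps` into a `"tool"`-role message and fed back to the model on the next loop iteration.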
    def stream_chat(
        self,
        messages: List[Dict[str, Any]],
        model: Optional[str] = None,
        max_tokens: Optional[int] = None,
        system_prompt: Optional[str] = None,
        online: bool = False,
        on_chunk: Optional[Callable[[StreamChunk], None]] = None,
    ) -> tuple[str, Optional[UsageStats]]:
        """
        Stream a chat response and collect the full text.

        Args:
            messages: Chat messages
            model: Model ID
            max_tokens: Maximum tokens
            system_prompt: System prompt
            online: Online mode
            on_chunk: Optional callback for each chunk

        Returns:
            Tuple of (full_response_text, usage_stats)
        """
        response = self.chat(
            messages=messages,
            model=model,
            stream=True,
            max_tokens=max_tokens,
            system_prompt=system_prompt,
            online=online,
        )

        if isinstance(response, ChatResponse):
            # Not actually streaming
            return response.content or "", response.usage

        full_text = ""
        usage: Optional[UsageStats] = None

        for chunk in response:
            if chunk.error:
                self.logger.error(f"Stream error: {chunk.error}")
                break

            if chunk.delta_content:
                full_text += chunk.delta_content
                if on_chunk:
                    on_chunk(chunk)

            if chunk.usage:
                usage = chunk.usage

        return full_text, usage
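The collection loop in `stream_chat` — accumulate `delta_content`, keep the last `usage` payload, stop on error — can be exercised standalone with a stand-in chunk type (`FakeChunk` below is illustrative, not the real `StreamChunk`):

```python
from dataclasses import dataclass
from typing import Iterable, Optional, Tuple

@dataclass
class FakeChunk:
    """Stand-in for StreamChunk: just the fields the loop reads."""
    delta_content: Optional[str] = None
    usage: Optional[dict] = None
    error: Optional[str] = None

def collect_stream(chunks: Iterable[FakeChunk]) -> Tuple[str, Optional[dict]]:
    """Accumulate delta_content and keep the last usage payload, stopping on error."""
    full_text = ""
    usage = None
    for chunk in chunks:
        if chunk.error:
            break
        if chunk.delta_content:
            full_text += chunk.delta_content
        if chunk.usage:
            usage = chunk.usage
    return full_text, usage

text, usage = collect_stream([
    FakeChunk(delta_content="Hel"),
    FakeChunk(delta_content="lo"),
    FakeChunk(usage={"prompt_tokens": 3, "completion_tokens": 2}),
])
```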
    def get_credits(self) -> Optional[Dict[str, Any]]:
        """
        Get account credit information.

        Returns:
            Credit info dict or None if unavailable
        """
        return self.provider.get_credits()

    def estimate_cost(
        self,
        model_id: str,
        input_tokens: int,
        output_tokens: int,
    ) -> float:
        """
        Estimate cost for a completion.

        Args:
            model_id: Model ID
            input_tokens: Number of input tokens
            output_tokens: Number of output tokens

        Returns:
            Estimated cost in USD
        """
        if hasattr(self.provider, "estimate_cost"):
            return self.provider.estimate_cost(model_id, input_tokens, output_tokens)

        # Fallback to default pricing
        input_cost = MODEL_PRICING["input"] * input_tokens / 1_000_000
        output_cost = MODEL_PRICING["output"] * output_tokens / 1_000_000
        return input_cost + output_cost
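The fallback path is a linear model: a per-million-token rate times the token count, summed over input and output. A standalone restatement (the rates in `PRICING` are placeholders; the real defaults live in oAI's `MODEL_PRICING` constant):

```python
# Placeholder per-1M-token rates in USD; real values come from oAI's MODEL_PRICING.
PRICING = {"input": 3.0, "output": 15.0}

def estimate_cost(input_tokens: int, output_tokens: int) -> float:
    """Linear cost model: USD-per-1M-token rate times token count, input + output."""
    input_cost = PRICING["input"] * input_tokens / 1_000_000
    output_cost = PRICING["output"] * output_tokens / 1_000_000
    return input_cost + output_cost
```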
    def set_default_model(self, model_id: str) -> None:
        """
        Set the default model.

        Args:
            model_id: Model ID to use as default
        """
        self.default_model = model_id
        self.logger.info(f"Default model set to: {model_id}")

    def clear_cache(self) -> None:
        """Clear the provider's model cache."""
        if hasattr(self.provider, "clear_cache"):
            self.provider.clear_cache()
oai/core/session.py (new file, 891 lines)
@@ -0,0 +1,891 @@
"""
Chat session management for oAI.

This module provides the ChatSession class that manages an interactive
chat session including history, state, and message handling.
"""

import asyncio
import json
import time
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Tuple

from oai.commands.registry import CommandContext, CommandResult, registry
from oai.config.database import Database
from oai.config.settings import Settings
from oai.constants import (
    COST_WARNING_THRESHOLD,
    LOW_CREDIT_AMOUNT,
    LOW_CREDIT_RATIO,
)
from oai.core.client import AIClient
from oai.mcp.manager import MCPManager
from oai.providers.base import ChatResponse, StreamChunk, UsageStats
from oai.utils.logging import get_logger


@dataclass
class SessionStats:
    """
    Statistics for the current session.

    Tracks tokens, costs, and message counts.
    """

    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_cost: float = 0.0
    message_count: int = 0

    @property
    def total_tokens(self) -> int:
        """Get total token count."""
        return self.total_input_tokens + self.total_output_tokens

    def add_usage(self, usage: Optional[UsageStats], cost: float = 0.0) -> None:
        """
        Add usage stats from a response.

        Args:
            usage: Usage statistics
            cost: Cost if not in usage
        """
        if usage:
            self.total_input_tokens += usage.prompt_tokens
            self.total_output_tokens += usage.completion_tokens
            if usage.total_cost_usd:
                self.total_cost += usage.total_cost_usd
            else:
                self.total_cost += cost
        else:
            self.total_cost += cost
        self.message_count += 1
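The accumulation rule — prefer the provider-reported cost, fall back to the caller's estimate, always bump the message count — can be restated standalone (`MiniStats` below mirrors the dataclass with a plain dict in place of `UsageStats`; it is not the real class):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class MiniStats:
    """Stripped-down re-statement of SessionStats for illustration."""
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_cost: float = 0.0
    message_count: int = 0

    def add_usage(self, usage: Optional[dict], cost: float = 0.0) -> None:
        # Prefer the provider-reported cost; fall back to the caller's estimate.
        if usage:
            self.total_input_tokens += usage["prompt_tokens"]
            self.total_output_tokens += usage["completion_tokens"]
            self.total_cost += usage.get("total_cost_usd") or cost
        else:
            self.total_cost += cost
        self.message_count += 1

stats = MiniStats()
stats.add_usage({"prompt_tokens": 10, "completion_tokens": 5, "total_cost_usd": 0.01})
stats.add_usage(None, cost=0.002)
```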
@dataclass
class HistoryEntry:
    """
    A single entry in the conversation history.

    Stores the user prompt, assistant response, and metrics.
    """

    prompt: str
    response: str
    prompt_tokens: int = 0
    completion_tokens: int = 0
    msg_cost: float = 0.0
    timestamp: Optional[float] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format."""
        return {
            "prompt": self.prompt,
            "response": self.response,
            "prompt_tokens": self.prompt_tokens,
            "completion_tokens": self.completion_tokens,
            "msg_cost": self.msg_cost,
        }
class ChatSession:
    """
    Manages an interactive chat session.

    Handles conversation history, state management, command processing,
    and communication with the AI client.

    Attributes:
        client: AI client for API requests
        settings: Application settings
        mcp_manager: MCP manager for file/database access
        history: Conversation history
        stats: Session statistics
    """

    def __init__(
        self,
        client: AIClient,
        settings: Settings,
        mcp_manager: Optional[MCPManager] = None,
    ):
        """
        Initialize a chat session.

        Args:
            client: AI client instance
            settings: Application settings
            mcp_manager: Optional MCP manager
        """
        self.client = client
        self.settings = settings
        self.mcp_manager = mcp_manager
        self.db = Database()

        self.history: List[HistoryEntry] = []
        self.stats = SessionStats()

        # Session state
        self.system_prompt: str = settings.effective_system_prompt
        self.memory_enabled: bool = True
        self.memory_start_index: int = 0
        self.online_enabled: bool = settings.default_online_mode
        self.middle_out_enabled: bool = False
        self.session_max_token: int = 0
        self.current_index: int = 0

        # Selected model
        self.selected_model: Optional[Dict[str, Any]] = None

        self.logger = get_logger()

    def get_context(self) -> CommandContext:
        """
        Get the current command context.

        Returns:
            CommandContext with current session state
        """
        return CommandContext(
            settings=self.settings,
            provider=self.client.provider,
            mcp_manager=self.mcp_manager,
            selected_model_raw=self.selected_model,
            session_history=[e.to_dict() for e in self.history],
            session_system_prompt=self.system_prompt,
            memory_enabled=self.memory_enabled,
            memory_start_index=self.memory_start_index,
            online_enabled=self.online_enabled,
            middle_out_enabled=self.middle_out_enabled,
            session_max_token=self.session_max_token,
            total_input_tokens=self.stats.total_input_tokens,
            total_output_tokens=self.stats.total_output_tokens,
            total_cost=self.stats.total_cost,
            message_count=self.stats.message_count,
            current_index=self.current_index,
        )

    def set_model(self, model: Dict[str, Any]) -> None:
        """
        Set the selected model.

        Args:
            model: Raw model dictionary
        """
        self.selected_model = model
        self.client.set_default_model(model["id"])
        self.logger.info(f"Model selected: {model['id']}")
    def build_api_messages(self, user_input: str) -> List[Dict[str, Any]]:
        """
        Build the messages array for an API request.

        Includes system prompt, history (if memory enabled), and current input.

        Args:
            user_input: Current user input

        Returns:
            List of message dictionaries
        """
        messages = []

        # Add system prompt
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})

        # Add database context if in database mode
        if self.mcp_manager and self.mcp_manager.enabled:
            if self.mcp_manager.mode == "database" and self.mcp_manager.selected_db_index is not None:
                db = self.mcp_manager.databases[self.mcp_manager.selected_db_index]
                db_context = (
                    f"You are connected to SQLite database: {db['name']}\n"
                    f"Available tables: {', '.join(db['tables'])}\n\n"
                    "Use inspect_database, search_database, or query_database tools. "
                    "All queries are read-only."
                )
                messages.append({"role": "system", "content": db_context})

        # Add history if memory enabled
        if self.memory_enabled:
            for i in range(self.memory_start_index, len(self.history)):
                entry = self.history[i]
                messages.append({"role": "user", "content": entry.prompt})
                messages.append({"role": "assistant", "content": entry.response})

        # Add current message
        messages.append({"role": "user", "content": user_input})

        return messages
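The shape of the array `build_api_messages` produces can be sketched standalone — system prompt first, then replayed history pairs from the memory start index, then the current input (history entries here are plain `(prompt, response)` tuples rather than `HistoryEntry` objects, and the database-context message is omitted):

```python
from typing import Any, Dict, List, Tuple

def build_messages(
    system_prompt: str,
    history: List[Tuple[str, str]],
    user_input: str,
    memory_enabled: bool = True,
    memory_start_index: int = 0,
) -> List[Dict[str, Any]]:
    """System prompt first, then replayed history pairs, then the current input."""
    messages: List[Dict[str, Any]] = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    if memory_enabled:
        for prompt, response in history[memory_start_index:]:
            messages.append({"role": "user", "content": prompt})
            messages.append({"role": "assistant", "content": response})
    messages.append({"role": "user", "content": user_input})
    return messages

msgs = build_messages("Be brief.", [("hi", "hello")], "how are you?")
```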
    def get_mcp_tools(self) -> Optional[List[Dict[str, Any]]]:
        """
        Get MCP tool definitions if available.

        Returns:
            List of tool schemas or None
        """
        if not self.mcp_manager or not self.mcp_manager.enabled:
            return None

        if not self.selected_model:
            return None

        # Check if model supports tools
        supported_params = self.selected_model.get("supported_parameters", [])
        if "tools" not in supported_params and "functions" not in supported_params:
            return None

        return self.mcp_manager.get_tools_schema()

    async def execute_tool(
        self,
        tool_name: str,
        tool_args: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Execute an MCP tool.

        Args:
            tool_name: Name of the tool
            tool_args: Tool arguments

        Returns:
            Tool execution result
        """
        if not self.mcp_manager:
            return {"error": "MCP not available"}

        return await self.mcp_manager.call_tool(tool_name, **tool_args)
    def send_message(
        self,
        user_input: str,
        stream: bool = True,
        on_stream_chunk: Optional[Callable[[str], None]] = None,
    ) -> Tuple[str, Optional[UsageStats], float]:
        """
        Send a message and get a response.

        Args:
            user_input: User's input text
            stream: Whether to stream the response
            on_stream_chunk: Callback for stream chunks

        Returns:
            Tuple of (response_text, usage_stats, response_time)
        """
        if not self.selected_model:
            raise ValueError("No model selected")

        start_time = time.time()
        messages = self.build_api_messages(user_input)

        # Get MCP tools
        tools = self.get_mcp_tools()
        if tools:
            # Disable streaming when tools are present
            stream = False

        # Build request parameters
        model_id = self.selected_model["id"]
        if self.online_enabled:
            if hasattr(self.client.provider, "get_effective_model_id"):
                model_id = self.client.provider.get_effective_model_id(model_id, True)

        transforms = ["middle-out"] if self.middle_out_enabled else None

        max_tokens = None
        if self.session_max_token > 0:
            max_tokens = self.session_max_token

        if tools:
            # Use tool handling flow
            response = self._send_with_tools(
                messages=messages,
                model_id=model_id,
                tools=tools,
                max_tokens=max_tokens,
                transforms=transforms,
            )
            response_time = time.time() - start_time
            return response.content or "", response.usage, response_time

        elif stream:
            # Use streaming flow
            full_text, usage = self._stream_response(
                messages=messages,
                model_id=model_id,
                max_tokens=max_tokens,
                transforms=transforms,
                on_chunk=on_stream_chunk,
            )
            response_time = time.time() - start_time
            return full_text, usage, response_time

        else:
            # Non-streaming request
            response = self.client.chat(
                messages=messages,
                model=model_id,
                stream=False,
                max_tokens=max_tokens,
                transforms=transforms,
            )
            response_time = time.time() - start_time

            if isinstance(response, ChatResponse):
                return response.content or "", response.usage, response_time
            else:
                return "", None, response_time
    def _send_with_tools(
        self,
        messages: List[Dict[str, Any]],
        model_id: str,
        tools: List[Dict[str, Any]],
        max_tokens: Optional[int] = None,
        transforms: Optional[List[str]] = None,
    ) -> ChatResponse:
        """
        Send a request with tool call handling.

        Args:
            messages: API messages
            model_id: Model ID
            tools: Tool definitions
            max_tokens: Max tokens
            transforms: Transforms list

        Returns:
            Final ChatResponse
        """
        max_loops = 5
        loop_count = 0
        api_messages = list(messages)

        while loop_count < max_loops:
            response = self.client.chat(
                messages=api_messages,
                model=model_id,
                stream=False,
                max_tokens=max_tokens,
                tools=tools,
                tool_choice="auto",
                transforms=transforms,
            )

            if not isinstance(response, ChatResponse):
                raise ValueError("Expected ChatResponse")

            tool_calls = response.tool_calls
            if not tool_calls:
                return response

            # Tool calls requested by AI
            tool_results = []
            for tc in tool_calls:
                try:
                    args = json.loads(tc.function.arguments)
                except json.JSONDecodeError as e:
                    self.logger.error(f"Failed to parse tool arguments: {e}")
                    tool_results.append({
                        "tool_call_id": tc.id,
                        "role": "tool",
                        "name": tc.function.name,
                        "content": json.dumps({"error": f"Invalid arguments: {e}"}),
                    })
                    continue

                # Display tool call
                args_display = ", ".join(
                    f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}"
                    for k, v in args.items()
                )
                # Executing tool: {tc.function.name}

                # Execute tool
                result = asyncio.run(self.execute_tool(tc.function.name, args))

                if "error" in result:
                    # Tool execution error logged
                    pass
                else:
                    # Tool execution successful
                    pass

                tool_results.append({
                    "tool_call_id": tc.id,
                    "role": "tool",
                    "name": tc.function.name,
                    "content": json.dumps(result),
                })

            # Add assistant message with tool calls
            api_messages.append({
                "role": "assistant",
                "content": response.content,
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": tc.type,
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments,
                        },
                    }
                    for tc in tool_calls
                ],
            })
            api_messages.extend(tool_results)

            # Processing tool results
            loop_count += 1

        self.logger.warning(f"Reached max tool loops ({max_loops})")
        return response
    def _stream_response(
        self,
        messages: List[Dict[str, Any]],
        model_id: str,
        max_tokens: Optional[int] = None,
        transforms: Optional[List[str]] = None,
        on_chunk: Optional[Callable[[str], None]] = None,
    ) -> Tuple[str, Optional[UsageStats]]:
        """
        Stream a response with live display.

        Args:
            messages: API messages
            model_id: Model ID
            max_tokens: Max tokens
            transforms: Transforms
            on_chunk: Callback for chunks

        Returns:
            Tuple of (full_text, usage)
        """
        response = self.client.chat(
            messages=messages,
            model=model_id,
            stream=True,
            max_tokens=max_tokens,
            transforms=transforms,
        )

        if isinstance(response, ChatResponse):
            return response.content or "", response.usage

        full_text = ""
        usage: Optional[UsageStats] = None

        try:
            for chunk in response:
                if chunk.error:
                    self.logger.error(f"Stream error: {chunk.error}")
                    break

                if chunk.delta_content:
                    full_text += chunk.delta_content
                    if on_chunk:
                        on_chunk(chunk.delta_content)

                if chunk.usage:
                    usage = chunk.usage

        except KeyboardInterrupt:
            self.logger.info("Streaming interrupted")
            return "", None

        return full_text, usage

    # ========== ASYNC METHODS FOR TUI ==========
    async def send_message_async(
        self,
        user_input: str,
        stream: bool = True,
    ) -> AsyncIterator[StreamChunk]:
        """
        Async version of send_message for Textual TUI.

        Args:
            user_input: User's input text
            stream: Whether to stream the response

        Yields:
            StreamChunk objects for progressive display
        """
        if not self.selected_model:
            raise ValueError("No model selected")

        messages = self.build_api_messages(user_input)
        tools = self.get_mcp_tools()

        if tools:
            # Disable streaming when tools are present
            stream = False

        model_id = self.selected_model["id"]
        if self.online_enabled:
            if hasattr(self.client.provider, "get_effective_model_id"):
                model_id = self.client.provider.get_effective_model_id(model_id, True)

        transforms = ["middle-out"] if self.middle_out_enabled else None
        max_tokens = None
        if self.session_max_token > 0:
            max_tokens = self.session_max_token

        if tools:
            # Use async tool handling flow
            async for chunk in self._send_with_tools_async(
                messages=messages,
                model_id=model_id,
                tools=tools,
                max_tokens=max_tokens,
                transforms=transforms,
            ):
                yield chunk
        elif stream:
            # Use async streaming flow
            async for chunk in self._stream_response_async(
                messages=messages,
                model_id=model_id,
                max_tokens=max_tokens,
                transforms=transforms,
            ):
                yield chunk
        else:
            # Non-streaming request
            response = self.client.chat(
                messages=messages,
                model=model_id,
                stream=False,
                max_tokens=max_tokens,
                transforms=transforms,
            )
            if isinstance(response, ChatResponse):
                # Yield single chunk with complete response
                chunk = StreamChunk(
                    id="",
                    delta_content=response.content,
                    usage=response.usage,
                    error=None,
                )
                yield chunk
    async def _send_with_tools_async(
        self,
        messages: List[Dict[str, Any]],
        model_id: str,
        tools: List[Dict[str, Any]],
        max_tokens: Optional[int] = None,
        transforms: Optional[List[str]] = None,
    ) -> AsyncIterator[StreamChunk]:
        """
        Async version of _send_with_tools for TUI.

        Args:
            messages: API messages
            model_id: Model ID
            tools: Tool definitions
            max_tokens: Max tokens
            transforms: Transforms list

        Yields:
            StreamChunk objects including tool call notifications
        """
        max_loops = 5
        loop_count = 0
        api_messages = list(messages)

        while loop_count < max_loops:
            response = self.client.chat(
                messages=api_messages,
                model=model_id,
                stream=False,
                max_tokens=max_tokens,
                tools=tools,
                tool_choice="auto",
                transforms=transforms,
            )

            if not isinstance(response, ChatResponse):
                raise ValueError("Expected ChatResponse")

            tool_calls = response.tool_calls
            if not tool_calls:
                # Final response, yield it
                chunk = StreamChunk(
                    id="",
                    delta_content=response.content,
                    usage=response.usage,
                    error=None,
                )
                yield chunk
                return

            # Yield notification about tool calls
            tool_notification = f"\n🔧 AI requesting {len(tool_calls)} tool call(s)...\n"
            yield StreamChunk(id="", delta_content=tool_notification, usage=None, error=None)

            tool_results = []
            for tc in tool_calls:
                try:
                    args = json.loads(tc.function.arguments)
                except json.JSONDecodeError as e:
                    self.logger.error(f"Failed to parse tool arguments: {e}")
                    tool_results.append({
                        "tool_call_id": tc.id,
                        "role": "tool",
                        "name": tc.function.name,
                        "content": json.dumps({"error": f"Invalid arguments: {e}"}),
                    })
                    continue

                # Yield tool call display
                args_display = ", ".join(
                    f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}"
                    for k, v in args.items()
                )
                tool_display = f" → {tc.function.name}({args_display})\n"
                yield StreamChunk(id="", delta_content=tool_display, usage=None, error=None)

                # Execute tool (await instead of asyncio.run)
                result = await self.execute_tool(tc.function.name, args)

                if "error" in result:
                    error_msg = f" ✗ Error: {result['error']}\n"
                    yield StreamChunk(id="", delta_content=error_msg, usage=None, error=None)
                else:
                    success_msg = self._format_tool_success(tc.function.name, result)
                    yield StreamChunk(id="", delta_content=success_msg, usage=None, error=None)

                tool_results.append({
                    "tool_call_id": tc.id,
                    "role": "tool",
                    "name": tc.function.name,
                    "content": json.dumps(result),
                })

            # Add assistant message with tool calls
            api_messages.append({
                "role": "assistant",
                "content": response.content,
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": tc.type,
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments,
                        },
                    }
                    for tc in tool_calls
                ],
            })

            # Add tool results
            api_messages.extend(tool_results)
            loop_count += 1

        # Max loops reached
        yield StreamChunk(
            id="",
            delta_content="\n⚠️ Maximum tool call loops reached\n",
            usage=None,
            error="Max loops reached"
        )
def _format_tool_success(self, tool_name: str, result: Dict[str, Any]) -> str:
|
||||||
|
"""Format a success message for a tool call."""
|
||||||
|
if tool_name == "search_files":
|
||||||
|
count = result.get("count", 0)
|
||||||
|
return f" ✓ Found {count} file(s)\n"
|
||||||
|
elif tool_name == "read_file":
|
||||||
|
size = result.get("size", 0)
|
||||||
|
truncated = " (truncated)" if result.get("truncated") else ""
|
||||||
|
return f" ✓ Read {size} bytes{truncated}\n"
|
||||||
|
elif tool_name == "list_directory":
|
||||||
|
count = result.get("count", 0)
|
||||||
|
return f" ✓ Listed {count} item(s)\n"
|
||||||
|
elif tool_name == "inspect_database":
|
||||||
|
if "table" in result:
|
||||||
|
return f" ✓ Inspected table: {result['table']}\n"
|
||||||
|
else:
|
||||||
|
return f" ✓ Inspected database ({result.get('table_count', 0)} tables)\n"
|
||||||
|
elif tool_name == "search_database":
|
||||||
|
count = result.get("count", 0)
|
||||||
|
return f" ✓ Found {count} match(es)\n"
|
||||||
|
elif tool_name == "query_database":
|
||||||
|
count = result.get("count", 0)
|
||||||
|
return f" ✓ Query returned {count} row(s)\n"
|
||||||
|
else:
|
||||||
|
return " ✓ Success\n"
|
||||||
|
|
||||||
|
async def _stream_response_async(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, Any]],
|
||||||
|
model_id: str,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
transforms: Optional[List[str]] = None,
|
||||||
|
) -> AsyncIterator[StreamChunk]:
|
||||||
|
"""
|
||||||
|
Async version of _stream_response for TUI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: API messages
|
||||||
|
model_id: Model ID
|
||||||
|
max_tokens: Max tokens
|
||||||
|
transforms: Transforms
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
StreamChunk objects
|
||||||
|
"""
|
||||||
|
response = self.client.chat(
|
||||||
|
messages=messages,
|
||||||
|
model=model_id,
|
||||||
|
stream=True,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
transforms=transforms,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(response, ChatResponse):
|
||||||
|
# Non-streaming response
|
||||||
|
chunk = StreamChunk(
|
||||||
|
id="",
|
||||||
|
delta_content=response.content,
|
||||||
|
usage=response.usage,
|
||||||
|
error=None,
|
||||||
|
)
|
||||||
|
yield chunk
|
||||||
|
return
|
||||||
|
|
||||||
|
# Stream the response
|
||||||
|
for chunk in response:
|
||||||
|
if chunk.error:
|
||||||
|
yield StreamChunk(id="", delta_content=None, usage=None, error=chunk.error)
|
||||||
|
break
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
# ========== END ASYNC METHODS ==========
|
||||||
|
|
||||||
|
def add_to_history(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
response: str,
|
||||||
|
usage: Optional[UsageStats] = None,
|
||||||
|
cost: float = 0.0,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Add an exchange to the history.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt
|
||||||
|
response: Assistant response
|
||||||
|
usage: Usage statistics
|
||||||
|
cost: Cost if not in usage
|
||||||
|
"""
|
||||||
|
entry = HistoryEntry(
|
||||||
|
prompt=prompt,
|
||||||
|
response=response,
|
||||||
|
prompt_tokens=usage.prompt_tokens if usage else 0,
|
||||||
|
completion_tokens=usage.completion_tokens if usage else 0,
|
||||||
|
msg_cost=usage.total_cost_usd if usage and usage.total_cost_usd else cost,
|
||||||
|
timestamp=time.time(),
|
||||||
|
)
|
||||||
|
self.history.append(entry)
|
||||||
|
self.current_index = len(self.history) - 1
|
||||||
|
self.stats.add_usage(usage, cost)
|
||||||
|
|
||||||
|
def save_conversation(self, name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Save the current conversation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Name for the saved conversation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if saved successfully
|
||||||
|
"""
|
||||||
|
if not self.history:
|
||||||
|
return False
|
||||||
|
|
||||||
|
data = [e.to_dict() for e in self.history]
|
||||||
|
self.db.save_conversation(name, data)
|
||||||
|
self.logger.info(f"Saved conversation: {name}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def load_conversation(self, name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Load a saved conversation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Name of the conversation to load
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if loaded successfully
|
||||||
|
"""
|
||||||
|
data = self.db.load_conversation(name)
|
||||||
|
if not data:
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.history.clear()
|
||||||
|
for entry_dict in data:
|
||||||
|
self.history.append(HistoryEntry(
|
||||||
|
prompt=entry_dict.get("prompt", ""),
|
||||||
|
response=entry_dict.get("response", ""),
|
||||||
|
prompt_tokens=entry_dict.get("prompt_tokens", 0),
|
||||||
|
completion_tokens=entry_dict.get("completion_tokens", 0),
|
||||||
|
msg_cost=entry_dict.get("msg_cost", 0.0),
|
||||||
|
))
|
||||||
|
|
||||||
|
self.current_index = len(self.history) - 1
|
||||||
|
self.memory_start_index = 0
|
||||||
|
self.stats = SessionStats() # Reset stats for loaded conversation
|
||||||
|
self.logger.info(f"Loaded conversation: {name}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset the session state."""
|
||||||
|
self.history.clear()
|
||||||
|
self.stats = SessionStats()
|
||||||
|
self.system_prompt = ""
|
||||||
|
self.memory_start_index = 0
|
||||||
|
self.current_index = 0
|
||||||
|
self.logger.info("Session reset")
|
||||||
|
|
||||||
|
def check_warnings(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Check for cost and credit warnings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of warning messages
|
||||||
|
"""
|
||||||
|
warnings = []
|
||||||
|
|
||||||
|
# Check last message cost
|
||||||
|
if self.history:
|
||||||
|
last_cost = self.history[-1].msg_cost
|
||||||
|
threshold = self.settings.cost_warning_threshold
|
||||||
|
if last_cost > threshold:
|
||||||
|
warnings.append(
|
||||||
|
f"High cost: ${last_cost:.4f} exceeds threshold ${threshold:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check credits
|
||||||
|
credits = self.client.get_credits()
|
||||||
|
if credits:
|
||||||
|
left = credits.get("credits_left", 0)
|
||||||
|
total = credits.get("total_credits", 0)
|
||||||
|
|
||||||
|
if left < LOW_CREDIT_AMOUNT:
|
||||||
|
warnings.append(f"Low credits: ${left:.2f} remaining!")
|
||||||
|
elif total > 0 and left < total * LOW_CREDIT_RATIO:
|
||||||
|
warnings.append(f"Credits low: less than 10% remaining (${left:.2f})")
|
||||||
|
|
||||||
|
return warnings
|
||||||
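The tool-call plumbing above appends one assistant message carrying the requested calls, then one `role: "tool"` message per result, linked back by `tool_call_id`. A standalone sketch of that message shape (the helper name and tuple layout are illustrative, not part of oAI):

```python
import json

def build_tool_messages(tool_calls, results):
    """Build the assistant + tool messages an OpenAI-style chat API expects.

    tool_calls: list of (id, name, arguments_json) tuples.
    results: dict mapping tool-call id to a result dict.
    """
    assistant_msg = {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {"id": tc_id, "type": "function",
             "function": {"name": name, "arguments": args}}
            for tc_id, name, args in tool_calls
        ],
    }
    # One tool-role message per result, keyed by tool_call_id so the
    # model can associate each result with the call that produced it.
    tool_msgs = [
        {"tool_call_id": tc_id, "role": "tool", "name": name,
         "content": json.dumps(results[tc_id])}
        for tc_id, name, _ in tool_calls
    ]
    return [assistant_msg] + tool_msgs

msgs = build_tool_messages(
    [("call_1", "read_file", '{"path": "notes.txt"}')],
    {"call_1": {"size": 120, "truncated": False}},
)
```

Both messages must be appended to the conversation before the follow-up request, which is exactly what the loop above does with `api_messages.append(...)` and `api_messages.extend(tool_results)`.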
28  oai/mcp/__init__.py  Normal file
@@ -0,0 +1,28 @@

"""
Model Context Protocol (MCP) integration for oAI.

This package provides filesystem and database access capabilities
through the MCP standard, allowing AI models to interact with
local files and SQLite databases safely.

Key components:
- MCPManager: High-level manager for MCP operations
- MCPFilesystemServer: Filesystem and database access implementation
- GitignoreParser: Pattern matching for .gitignore support
- SQLiteQueryValidator: Query safety validation
- CrossPlatformMCPConfig: OS-specific configuration
"""

from oai.mcp.manager import MCPManager
from oai.mcp.server import MCPFilesystemServer
from oai.mcp.gitignore import GitignoreParser
from oai.mcp.validators import SQLiteQueryValidator
from oai.mcp.platform import CrossPlatformMCPConfig

__all__ = [
    "MCPManager",
    "MCPFilesystemServer",
    "GitignoreParser",
    "SQLiteQueryValidator",
    "CrossPlatformMCPConfig",
]
166  oai/mcp/gitignore.py  Normal file
@@ -0,0 +1,166 @@

"""
Gitignore pattern parsing for oAI MCP.

This module implements .gitignore pattern matching to filter files
during MCP filesystem operations.
"""

import fnmatch
from pathlib import Path
from typing import List, Tuple

from oai.utils.logging import get_logger


class GitignoreParser:
    """
    Parse and apply .gitignore patterns.

    Supports standard gitignore syntax including:
    - Wildcards (*) and double wildcards (**)
    - Directory-only patterns (ending with /)
    - Negation patterns (starting with !)
    - Comments (lines starting with #)

    Patterns are applied relative to the directory containing
    the .gitignore file.
    """

    def __init__(self):
        """Initialize an empty pattern collection."""
        # List of (pattern, is_negation, source_dir)
        self.patterns: List[Tuple[str, bool, Path]] = []

    def add_gitignore(self, gitignore_path: Path) -> None:
        """
        Parse and add patterns from a .gitignore file.

        Args:
            gitignore_path: Path to the .gitignore file
        """
        logger = get_logger()

        if not gitignore_path.exists():
            return

        try:
            source_dir = gitignore_path.parent

            with open(gitignore_path, "r", encoding="utf-8") as f:
                for line_num, line in enumerate(f, 1):
                    line = line.rstrip("\n\r")

                    # Skip empty lines and comments
                    if not line or line.startswith("#"):
                        continue

                    # Check for negation pattern
                    is_negation = line.startswith("!")
                    if is_negation:
                        line = line[1:]

                    # Remove leading slash (make relative to gitignore location)
                    if line.startswith("/"):
                        line = line[1:]

                    self.patterns.append((line, is_negation, source_dir))

            logger.debug(
                f"Loaded {len(self.patterns)} patterns from {gitignore_path}"
            )

        except Exception as e:
            logger.warning(f"Error reading {gitignore_path}: {e}")

    def should_ignore(self, path: Path) -> bool:
        """
        Check if a path should be ignored based on gitignore patterns.

        Patterns are evaluated in order, with later patterns overriding
        earlier ones. Negation patterns (starting with !) un-ignore
        previously matched paths.

        Args:
            path: Path to check

        Returns:
            True if the path should be ignored
        """
        if not self.patterns:
            return False

        ignored = False

        for pattern, is_negation, source_dir in self.patterns:
            # Only apply pattern if path is under the source directory
            try:
                rel_path = path.relative_to(source_dir)
            except ValueError:
                # Path is not relative to this gitignore's directory
                continue

            rel_path_str = str(rel_path)

            # Check if pattern matches
            if self._match_pattern(pattern, rel_path_str, path.is_dir()):
                if is_negation:
                    ignored = False  # Negation patterns un-ignore
                else:
                    ignored = True

        return ignored

    def _match_pattern(self, pattern: str, path: str, is_dir: bool) -> bool:
        """
        Match a gitignore pattern against a path.

        Args:
            pattern: The gitignore pattern
            path: The relative path string to match
            is_dir: Whether the path is a directory

        Returns:
            True if the pattern matches
        """
        # Directory-only pattern (ends with /)
        if pattern.endswith("/"):
            if not is_dir:
                return False
            pattern = pattern[:-1]

        # Handle ** patterns (matches any number of directories)
        if "**" in pattern:
            pattern_parts = pattern.split("**")
            if len(pattern_parts) == 2:
                prefix, suffix = pattern_parts

                # Match if path starts with prefix and ends with suffix
                if prefix:
                    if not path.startswith(prefix.rstrip("/")):
                        return False
                if suffix:
                    suffix = suffix.lstrip("/")
                    if not (path.endswith(suffix) or f"/{suffix}" in path):
                        return False
                return True

        # Direct match using fnmatch
        if fnmatch.fnmatch(path, pattern):
            return True

        # Match as subdirectory pattern (pattern without / matches in any directory)
        if "/" not in pattern:
            parts = path.split("/")
            if any(fnmatch.fnmatch(part, pattern) for part in parts):
                return True

        return False

    def clear(self) -> None:
        """Clear all loaded patterns."""
        self.patterns = []

    @property
    def pattern_count(self) -> int:
        """Get the number of loaded patterns."""
        return len(self.patterns)
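The fallback branch of the matcher above, where a pattern without a slash (e.g. `*.log` or `build`) matches any path component, can be exercised standalone. A minimal sketch using only `fnmatch` (the function name is illustrative):

```python
import fnmatch

def matches_any_component(pattern: str, rel_path: str) -> bool:
    """Mirror the bare-name branch of GitignoreParser._match_pattern:
    a pattern with no slash matches if any path component matches."""
    # Direct match on the whole relative path first
    if fnmatch.fnmatch(rel_path, pattern):
        return True
    # A slash-free pattern may match any single component
    if "/" not in pattern:
        return any(fnmatch.fnmatch(part, pattern) for part in rel_path.split("/"))
    return False

print(matches_any_component("*.log", "build/output/app.log"))  # True
print(matches_any_component("build", "build/output/app.log"))  # True
print(matches_any_component("src/*.py", "build/app.log"))      # False
```

Note that `fnmatch`'s `*` is not slash-aware (unlike git's own matcher), which is why the parser handles `**` patterns explicitly before falling back to `fnmatch`.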
1365  oai/mcp/manager.py  Normal file
File diff suppressed because it is too large.
228  oai/mcp/platform.py  Normal file
@@ -0,0 +1,228 @@

"""
Cross-platform MCP configuration for oAI.

This module handles OS-specific configuration, path handling,
and security checks for the MCP filesystem server.
"""

import os
import platform
import subprocess
from pathlib import Path
from typing import List, Dict, Any, Optional

from oai.constants import SYSTEM_DIRS_BLACKLIST
from oai.utils.logging import get_logger


class CrossPlatformMCPConfig:
    """
    Handle OS-specific MCP configuration.

    Provides methods for path normalization, security validation,
    and OS-specific default directories.

    Attributes:
        system: Operating system name
        is_macos: Whether running on macOS
        is_linux: Whether running on Linux
        is_windows: Whether running on Windows
    """

    def __init__(self):
        """Initialize platform detection."""
        self.system = platform.system()
        self.is_macos = self.system == "Darwin"
        self.is_linux = self.system == "Linux"
        self.is_windows = self.system == "Windows"

        logger = get_logger()
        logger.info(f"Detected OS: {self.system}")

    def get_default_allowed_dirs(self) -> List[Path]:
        """
        Get safe default directories for the current OS.

        Returns:
            List of default directories that are safe to access
        """
        home = Path.home()

        if self.is_macos:
            return [
                home / "Documents",
                home / "Desktop",
                home / "Downloads",
            ]

        elif self.is_linux:
            dirs = [home / "Documents"]

            # Try to get XDG directories
            try:
                for xdg_dir in ["DOCUMENTS", "DESKTOP", "DOWNLOAD"]:
                    result = subprocess.run(
                        ["xdg-user-dir", xdg_dir],
                        capture_output=True,
                        text=True,
                        timeout=1
                    )
                    if result.returncode == 0:
                        dir_path = Path(result.stdout.strip())
                        if dir_path.exists():
                            dirs.append(dir_path)
            except (subprocess.TimeoutExpired, FileNotFoundError):
                # Fallback to standard locations
                dirs.extend([
                    home / "Desktop",
                    home / "Downloads",
                ])

            return list(set(dirs))

        elif self.is_windows:
            return [
                home / "Documents",
                home / "Desktop",
                home / "Downloads",
            ]

        # Fallback for unknown OS
        return [home]

    def get_python_command(self) -> str:
        """
        Get the Python executable path.

        Returns:
            Path to the Python executable
        """
        import sys
        return sys.executable

    def get_filesystem_warning(self) -> str:
        """
        Get OS-specific security warning message.

        Returns:
            Warning message for the current OS
        """
        if self.is_macos:
            return """
Note: macOS Security
The Filesystem MCP server needs access to your selected folder.
You may see a security prompt - click 'Allow' to proceed.
(System Settings > Privacy & Security > Files and Folders)
"""
        elif self.is_linux:
            return """
Note: Linux Security
The Filesystem MCP server will access your selected folder.
Ensure oAI has appropriate file permissions.
"""
        elif self.is_windows:
            return """
Note: Windows Security
The Filesystem MCP server will access your selected folder.
You may need to grant file access permissions.
"""
        return ""

    def normalize_path(self, path: str) -> Path:
        """
        Normalize a path for the current OS.

        Expands user directory (~) and resolves to absolute path.

        Args:
            path: Path string to normalize

        Returns:
            Normalized absolute Path
        """
        return Path(os.path.expanduser(path)).resolve()

    def is_system_directory(self, path: Path) -> bool:
        """
        Check if a path is a protected system directory.

        Args:
            path: Path to check

        Returns:
            True if the path is a system directory
        """
        path_str = str(path)
        for blocked in SYSTEM_DIRS_BLACKLIST:
            if path_str.startswith(blocked):
                return True
        return False

    def is_safe_path(self, requested_path: Path, allowed_dirs: List[Path]) -> bool:
        """
        Check if a path is within allowed directories.

        Args:
            requested_path: Path being requested
            allowed_dirs: List of allowed parent directories

        Returns:
            True if the path is within an allowed directory
        """
        try:
            requested = requested_path.resolve()

            for allowed in allowed_dirs:
                try:
                    allowed_resolved = allowed.resolve()
                    requested.relative_to(allowed_resolved)
                    return True
                except ValueError:
                    continue

            return False
        except Exception:
            return False

    def get_folder_stats(self, folder: Path) -> Dict[str, Any]:
        """
        Get statistics for a folder.

        Args:
            folder: Path to the folder

        Returns:
            Dictionary with folder statistics:
            - exists: Whether the folder exists
            - file_count: Number of files (if exists)
            - total_size: Total size in bytes (if exists)
            - size_mb: Size in megabytes (if exists)
            - error: Error message (if any)
        """
        logger = get_logger()

        try:
            if not folder.exists() or not folder.is_dir():
                return {"exists": False}

            file_count = 0
            total_size = 0

            for item in folder.rglob("*"):
                if item.is_file():
                    file_count += 1
                    try:
                        total_size += item.stat().st_size
                    except (OSError, PermissionError):
                        pass

            return {
                "exists": True,
                "file_count": file_count,
                "total_size": total_size,
                "size_mb": total_size / (1024 * 1024),
            }

        except Exception as e:
            logger.error(f"Error getting folder stats for {folder}: {e}")
            return {"exists": False, "error": str(e)}
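The containment check in `is_safe_path` works because `resolve()` collapses `..` segments before the `relative_to` test, so traversal attempts like `allowed/../secret` fail the check. A standalone sketch (function name and paths are illustrative):

```python
from pathlib import Path

def is_within(requested: str, allowed_dirs: list) -> bool:
    """Mirror is_safe_path: a request is safe only if it resolves to a
    location inside one of the allowed directories."""
    # resolve() normalizes ".." and symlinks even for nonexistent paths
    req = Path(requested).resolve()
    for allowed in allowed_dirs:
        try:
            # relative_to raises ValueError when req is outside allowed
            req.relative_to(Path(allowed).resolve())
            return True
        except ValueError:
            continue
    return False

print(is_within("/tmp/project/notes.txt", ["/tmp/project"]))     # True
print(is_within("/tmp/project/../secret.txt", ["/tmp/project"]))  # False
```

A plain string-prefix comparison would accept the `..` case; resolving both sides first is what makes the check robust.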
1368  oai/mcp/server.py  Normal file
File diff suppressed because it is too large.
123  oai/mcp/validators.py  Normal file
@@ -0,0 +1,123 @@

"""
Query validation for oAI MCP database operations.

This module provides safety validation for SQL queries to ensure
only read-only operations are executed.
"""

import re
from typing import Tuple

from oai.constants import DANGEROUS_SQL_KEYWORDS


class SQLiteQueryValidator:
    """
    Validate SQLite queries for read-only safety.

    Ensures that only SELECT queries (including CTEs) are allowed
    and blocks potentially dangerous operations like INSERT, UPDATE,
    DELETE, DROP, etc.
    """

    @staticmethod
    def is_safe_query(query: str) -> Tuple[bool, str]:
        """
        Validate that a query is a safe read-only SELECT.

        The validation:
        1. Checks that query starts with SELECT or WITH
        2. Strips string literals before checking for dangerous keywords
        3. Blocks any dangerous keywords outside of string literals

        Args:
            query: SQL query string to validate

        Returns:
            Tuple of (is_safe, error_message)
            - is_safe: True if the query is safe to execute
            - error_message: Description of why query is unsafe (empty if safe)

        Examples:
            >>> SQLiteQueryValidator.is_safe_query("SELECT * FROM users")
            (True, "")
            >>> SQLiteQueryValidator.is_safe_query("DELETE FROM users")
            (False, "Only SELECT queries are allowed...")
            >>> SQLiteQueryValidator.is_safe_query("SELECT 'DELETE' FROM users")
            (True, "")  # 'DELETE' is inside a string literal
        """
        query_upper = query.strip().upper()

        # Must start with SELECT or WITH (for CTEs)
        if not (query_upper.startswith("SELECT") or query_upper.startswith("WITH")):
            return False, "Only SELECT queries are allowed (including WITH/CTE)"

        # Remove string literals before checking for dangerous keywords
        # This prevents false positives when keywords appear in data
        query_no_strings = re.sub(r"'[^']*'", "", query_upper)
        query_no_strings = re.sub(r'"[^"]*"', "", query_no_strings)

        # Check for dangerous keywords outside of quotes
        for keyword in DANGEROUS_SQL_KEYWORDS:
            if re.search(r"\b" + keyword + r"\b", query_no_strings):
                return False, f"Keyword '{keyword}' not allowed in read-only mode"

        return True, ""

    @staticmethod
    def sanitize_table_name(table_name: str) -> str:
        """
        Sanitize a table name to prevent SQL injection.

        Only allows alphanumeric characters and underscores.

        Args:
            table_name: Table name to sanitize

        Returns:
            Sanitized table name

        Raises:
            ValueError: If table name contains invalid characters
        """
        # Remove any characters that aren't alphanumeric or underscore
        sanitized = re.sub(r"[^\w]", "", table_name)

        if not sanitized:
            raise ValueError("Table name cannot be empty after sanitization")

        if sanitized != table_name:
            raise ValueError(
                f"Table name contains invalid characters: {table_name}"
            )

        return sanitized

    @staticmethod
    def sanitize_column_name(column_name: str) -> str:
        """
        Sanitize a column name to prevent SQL injection.

        Only allows alphanumeric characters and underscores.

        Args:
            column_name: Column name to sanitize

        Returns:
            Sanitized column name

        Raises:
            ValueError: If column name contains invalid characters
        """
        # Remove any characters that aren't alphanumeric or underscore
        sanitized = re.sub(r"[^\w]", "", column_name)

        if not sanitized:
            raise ValueError("Column name cannot be empty after sanitization")

        if sanitized != column_name:
            raise ValueError(
                f"Column name contains invalid characters: {column_name}"
            )

        return sanitized
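The literal-stripping step in `is_safe_query` is what keeps keywords inside string data from tripping the check. A self-contained sketch of the same logic, with an illustrative subset of the project's `DANGEROUS_SQL_KEYWORDS` constant inlined:

```python
import re

# Illustrative subset; the real list lives in oai.constants.DANGEROUS_SQL_KEYWORDS
DANGEROUS = ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE",
             "ATTACH", "PRAGMA"]

def is_safe_query(query: str):
    """Mirror SQLiteQueryValidator.is_safe_query: allow SELECT/WITH only,
    ignoring keywords that appear inside string literals."""
    q = query.strip().upper()
    if not (q.startswith("SELECT") or q.startswith("WITH")):
        return False, "Only SELECT queries are allowed (including WITH/CTE)"
    # Strip single- and double-quoted literals before keyword scanning
    no_strings = re.sub(r"'[^']*'", "", q)
    no_strings = re.sub(r'"[^"]*"', "", no_strings)
    for kw in DANGEROUS:
        if re.search(r"\b" + kw + r"\b", no_strings):
            return False, f"Keyword '{kw}' not allowed in read-only mode"
    return True, ""

print(is_safe_query("SELECT * FROM users"))            # (True, "")
print(is_safe_query("DELETE FROM users")[0])           # False
print(is_safe_query("SELECT 'DELETE' FROM users")[0])  # True
```

The word-boundary `\b` anchors matter too: without them a harmless column name like `updated_at` would false-positive on `UPDATE`.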
32  oai/providers/__init__.py  Normal file
@@ -0,0 +1,32 @@

"""
Provider abstraction for oAI.

This module provides a unified interface for AI model providers,
enabling easy extension to support additional providers beyond OpenRouter.
"""

from oai.providers.base import (
    AIProvider,
    ChatMessage,
    ChatResponse,
    ToolCall,
    ToolFunction,
    UsageStats,
    ModelInfo,
    ProviderCapabilities,
)
from oai.providers.openrouter import OpenRouterProvider

__all__ = [
    # Base classes and types
    "AIProvider",
    "ChatMessage",
    "ChatResponse",
    "ToolCall",
    "ToolFunction",
    "UsageStats",
    "ModelInfo",
    "ProviderCapabilities",
    # Provider implementations
    "OpenRouterProvider",
]
413
oai/providers/base.py
Normal file
413
oai/providers/base.py
Normal file
@@ -0,0 +1,413 @@
|
|||||||
|
"""
|
||||||
|
Abstract base provider for AI model integration.
|
||||||
|
|
||||||
|
This module defines the interface that all AI providers must implement,
|
||||||
|
along with common data structures for requests and responses.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
AsyncIterator,
|
||||||
|
Dict,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageRole(str, Enum):
|
||||||
|
"""Message roles in a conversation."""
|
||||||
|
|
||||||
|
SYSTEM = "system"
|
||||||
|
USER = "user"
|
||||||
|
ASSISTANT = "assistant"
|
||||||
|
TOOL = "tool"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolFunction:
|
||||||
|
"""
|
||||||
|
Represents a function within a tool call.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
name: The function name
|
||||||
|
arguments: JSON string of function arguments
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
arguments: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolCall:
|
||||||
|
"""
|
||||||
|
Represents a tool/function call requested by the model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
id: Unique identifier for this tool call
|
||||||
|
type: Type of tool call (usually "function")
|
||||||
|
function: The function being called
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
type: str
|
||||||
|
function: ToolFunction
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UsageStats:
|
||||||
|
"""
|
||||||
|
Token usage statistics from an API response.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
prompt_tokens: Number of tokens in the prompt
|
||||||
|
completion_tokens: Number of tokens in the completion
|
||||||
|
total_tokens: Total tokens used
|
||||||
|
total_cost_usd: Cost in USD (if available from API)
|
||||||
|
"""
|
||||||
|
|
||||||
|
prompt_tokens: int = 0
|
||||||
|
completion_tokens: int = 0
|
||||||
|
total_tokens: int = 0
|
||||||
|
total_cost_usd: Optional[float] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_tokens(self) -> int:
|
||||||
|
"""Alias for prompt_tokens."""
|
||||||
|
return self.prompt_tokens
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_tokens(self) -> int:
|
||||||
|
"""Alias for completion_tokens."""
|
||||||
|
return self.completion_tokens
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatMessage:
|
||||||
|
"""
|
||||||
|
A single message in a chat conversation.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
role: The role of the message sender
|
||||||
|
content: Message content (text or structured content blocks)
|
||||||
|
name: Optional name for the sender
|
||||||
|
tool_calls: List of tool calls (for assistant messages)
|
||||||
|
tool_call_id: Tool call ID this message responds to (for tool messages)
|
||||||
|
"""
|
||||||
|
|
||||||
|
role: str
|
||||||
|
content: Union[str, List[Dict[str, Any]], None] = None
|
||||||
|
name: Optional[str] = None
|
||||||
|
tool_calls: Optional[List[ToolCall]] = None
|
||||||
|
tool_call_id: Optional[str] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary format for API requests."""
|
||||||
|
result: Dict[str, Any] = {"role": self.role}
|
||||||
|
|
||||||
|
if self.content is not None:
|
||||||
|
result["content"] = self.content
|
||||||
|
|
||||||
|
if self.name:
|
||||||
|
result["name"] = self.name
|
||||||
|
|
||||||
|
if self.tool_calls:
|
||||||
|
result["tool_calls"] = [
|
||||||
|
{
|
||||||
|
"id": tc.id,
|
||||||
|
"type": tc.type,
|
||||||
|
"function": {
|
||||||
|
"name": tc.function.name,
|
||||||
|
"arguments": tc.function.arguments,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for tc in self.tool_calls
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.tool_call_id:
|
||||||
|
result["tool_call_id"] = self.tool_call_id
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ChatResponseChoice:
    """
    A single choice in a chat response.

    Attributes:
        index: Index of this choice
        message: The response message
        finish_reason: Why the response ended
    """

    index: int
    message: ChatMessage
    finish_reason: Optional[str] = None


@dataclass
class ChatResponse:
    """
    Response from a chat completion request.

    Attributes:
        id: Unique identifier for this response
        choices: List of response choices
        usage: Token usage statistics
        model: Model that generated this response
        created: Unix timestamp of creation
    """

    id: str
    choices: List[ChatResponseChoice]
    usage: Optional[UsageStats] = None
    model: Optional[str] = None
    created: Optional[int] = None

    @property
    def message(self) -> Optional[ChatMessage]:
        """Get the first choice's message."""
        if self.choices:
            return self.choices[0].message
        return None

    @property
    def content(self) -> Optional[str]:
        """Get the text content of the first choice."""
        msg = self.message
        if msg and isinstance(msg.content, str):
            return msg.content
        return None

    @property
    def tool_calls(self) -> Optional[List[ToolCall]]:
        """Get tool calls from the first choice."""
        msg = self.message
        if msg:
            return msg.tool_calls
        return None


@dataclass
class StreamChunk:
    """
    A single chunk from a streaming response.

    Attributes:
        id: Response ID
        delta_content: New content in this chunk
        finish_reason: Finish reason (if this is the last chunk)
        usage: Usage stats (usually in the last chunk)
        error: Error message if something went wrong
    """

    id: str
    delta_content: Optional[str] = None
    finish_reason: Optional[str] = None
    usage: Optional[UsageStats] = None
    error: Optional[str] = None


@dataclass
class ModelInfo:
    """
    Information about an AI model.

    Attributes:
        id: Unique model identifier
        name: Display name
        description: Model description
        context_length: Maximum context window size
        pricing: Pricing info (input/output per million tokens)
        supported_parameters: List of supported API parameters
        input_modalities: Supported input types (text, image, etc.)
        output_modalities: Supported output types
    """

    id: str
    name: str
    description: str = ""
    context_length: int = 0
    pricing: Dict[str, float] = field(default_factory=dict)
    supported_parameters: List[str] = field(default_factory=list)
    input_modalities: List[str] = field(default_factory=lambda: ["text"])
    output_modalities: List[str] = field(default_factory=lambda: ["text"])

    def supports_images(self) -> bool:
        """Check if model supports image input."""
        return "image" in self.input_modalities

    def supports_tools(self) -> bool:
        """Check if model supports function calling/tools."""
        return "tools" in self.supported_parameters or "functions" in self.supported_parameters

    def supports_streaming(self) -> bool:
        """Check if model supports streaming responses."""
        return "stream" in self.supported_parameters

    def supports_online(self) -> bool:
        """Check if model supports web search (online mode)."""
        return self.supports_tools()
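The `ModelInfo` capability checks are plain membership tests on the metadata lists; a minimal local mirror (the model ID below is hypothetical) shows how they behave:

```python
from dataclasses import dataclass, field
from typing import List

# Minimal local mirror of ModelInfo's capability checks (illustration only).
@dataclass
class ModelInfo:
    id: str
    supported_parameters: List[str] = field(default_factory=list)
    input_modalities: List[str] = field(default_factory=lambda: ["text"])

    def supports_images(self) -> bool:
        # Image support is declared via input modalities.
        return "image" in self.input_modalities

    def supports_tools(self) -> bool:
        # Tool support may be advertised under either parameter name.
        return "tools" in self.supported_parameters or "functions" in self.supported_parameters

m = ModelInfo(
    id="example/vision-model",  # hypothetical ID for illustration
    supported_parameters=["tools", "stream"],
    input_modalities=["text", "image"],
)
print(m.supports_images(), m.supports_tools())
```

Because the defaults are text-only with no parameters, a model constructed from sparse API metadata conservatively reports no image or tool support.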
@dataclass
class ProviderCapabilities:
    """
    Capabilities supported by a provider.

    Attributes:
        streaming: Provider supports streaming responses
        tools: Provider supports function calling
        images: Provider supports image inputs
        online: Provider supports web search
        max_context: Maximum context length across all models
    """

    streaming: bool = True
    tools: bool = True
    images: bool = True
    online: bool = False
    max_context: int = 128000


class AIProvider(ABC):
    """
    Abstract base class for AI model providers.

    All provider implementations must inherit from this class
    and implement the required abstract methods.
    """

    def __init__(self, api_key: str, base_url: Optional[str] = None):
        """
        Initialize the provider.

        Args:
            api_key: API key for authentication
            base_url: Optional custom base URL for the API
        """
        self.api_key = api_key
        self.base_url = base_url

    @property
    @abstractmethod
    def name(self) -> str:
        """Get the provider name."""
        pass

    @property
    @abstractmethod
    def capabilities(self) -> ProviderCapabilities:
        """Get provider capabilities."""
        pass

    @abstractmethod
    def list_models(self) -> List[ModelInfo]:
        """
        Fetch available models from the provider.

        Returns:
            List of available models with their info
        """
        pass

    @abstractmethod
    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        """
        Get information about a specific model.

        Args:
            model_id: The model identifier

        Returns:
            Model information or None if not found
        """
        pass

    @abstractmethod
    def chat(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[ChatResponse, Iterator[StreamChunk]]:
        """
        Send a chat completion request.

        Args:
            model: Model ID to use
            messages: List of chat messages
            stream: Whether to stream the response
            max_tokens: Maximum tokens in response
            temperature: Sampling temperature
            tools: List of tool definitions for function calling
            tool_choice: How to handle tool selection ("auto", "none", etc.)
            **kwargs: Additional provider-specific parameters

        Returns:
            ChatResponse for non-streaming, Iterator[StreamChunk] for streaming
        """
        pass

    @abstractmethod
    async def chat_async(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
        """
        Send an async chat completion request.

        Args:
            model: Model ID to use
            messages: List of chat messages
            stream: Whether to stream the response
            max_tokens: Maximum tokens in response
            temperature: Sampling temperature
            tools: List of tool definitions for function calling
            tool_choice: How to handle tool selection
            **kwargs: Additional provider-specific parameters

        Returns:
            ChatResponse for non-streaming, AsyncIterator[StreamChunk] for streaming
        """
        pass

    @abstractmethod
    def get_credits(self) -> Optional[Dict[str, Any]]:
        """
        Get account credit/balance information.

        Returns:
            Dict with credit info or None if not supported
        """
        pass

    def validate_api_key(self) -> bool:
        """
        Validate that the API key is valid.

        Returns:
            True if API key is valid
        """
        try:
            self.list_models()
            return True
        except Exception:
            return False
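The only concrete method on the base class, `validate_api_key`, treats "can list models without raising" as proof of a valid key. A stub provider (the `EchoProvider` class below is hypothetical, not part of this diff) exercises that contract:

```python
# Hypothetical stub provider exercising the base-class contract:
# validate_api_key succeeds iff list_models does not raise.
class EchoProvider:
    def __init__(self, api_key: str):
        self.api_key = api_key

    def list_models(self):
        if not self.api_key:
            raise ValueError("missing API key")
        return [{"id": "echo/echo-1"}]  # placeholder model entry

    # Same logic as AIProvider.validate_api_key above.
    def validate_api_key(self) -> bool:
        try:
            self.list_models()
            return True
        except Exception:
            return False

print(EchoProvider("sk-test").validate_api_key())
print(EchoProvider("").validate_api_key())
```

The broad `except Exception` means any network, auth, or parsing failure reads as "invalid key", which is coarse but safe for a pre-flight check.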
630  oai/providers/openrouter.py  Normal file
@@ -0,0 +1,630 @@
"""
|
||||||
|
OpenRouter provider implementation.
|
||||||
|
|
||||||
|
This module implements the AIProvider interface for OpenRouter,
|
||||||
|
supporting chat completions, streaming, and function calling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from openrouter import OpenRouter
|
||||||
|
|
||||||
|
from oai.constants import APP_NAME, APP_URL, DEFAULT_BASE_URL
|
||||||
|
from oai.providers.base import (
|
||||||
|
AIProvider,
|
||||||
|
ChatMessage,
|
||||||
|
ChatResponse,
|
||||||
|
ChatResponseChoice,
|
||||||
|
ModelInfo,
|
||||||
|
ProviderCapabilities,
|
||||||
|
StreamChunk,
|
||||||
|
ToolCall,
|
||||||
|
ToolFunction,
|
||||||
|
UsageStats,
|
||||||
|
)
|
||||||
|
from oai.utils.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
class OpenRouterProvider(AIProvider):
|
||||||
|
"""
|
||||||
|
OpenRouter API provider implementation.
|
||||||
|
|
||||||
|
Provides access to multiple AI models through OpenRouter's unified API,
|
||||||
|
supporting chat completions, streaming responses, and function calling.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
client: The underlying OpenRouter client
|
||||||
|
_models_cache: Cached list of available models
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
app_name: str = APP_NAME,
|
||||||
|
app_url: str = APP_URL,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the OpenRouter provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: OpenRouter API key
|
||||||
|
base_url: Optional custom base URL
|
||||||
|
app_name: Application name for API headers
|
||||||
|
app_url: Application URL for API headers
|
||||||
|
"""
|
||||||
|
super().__init__(api_key, base_url or DEFAULT_BASE_URL)
|
||||||
|
self.app_name = app_name
|
||||||
|
self.app_url = app_url
|
||||||
|
self.client = OpenRouter(api_key=api_key)
|
||||||
|
self._models_cache: Optional[List[ModelInfo]] = None
|
||||||
|
self._raw_models_cache: Optional[List[Dict[str, Any]]] = None
|
||||||
|
|
||||||
|
self.logger = get_logger()
|
||||||
|
self.logger.info(f"OpenRouter provider initialized with base URL: {self.base_url}")
|
||||||
|
|
||||||
|
    @property
    def name(self) -> str:
        """Get the provider name."""
        return "OpenRouter"

    @property
    def capabilities(self) -> ProviderCapabilities:
        """Get provider capabilities."""
        return ProviderCapabilities(
            streaming=True,
            tools=True,
            images=True,
            online=True,
            max_context=2000000,  # largest-context models support up to 2M tokens
        )

    def _get_headers(self) -> Dict[str, str]:
        """Get standard HTTP headers for API requests."""
        headers = {
            "HTTP-Referer": self.app_url,
            "X-Title": self.app_name,
        }
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def _parse_model(self, model_data: Dict[str, Any]) -> ModelInfo:
        """
        Parse raw model data into ModelInfo.

        Args:
            model_data: Raw model data from API

        Returns:
            Parsed ModelInfo object
        """
        architecture = model_data.get("architecture", {})
        pricing_data = model_data.get("pricing", {})

        # Parse pricing (convert from string to float if needed)
        pricing = {}
        for key in ["prompt", "completion"]:
            value = pricing_data.get(key)
            if value is not None:
                try:
                    # Convert from per-token to per-million-tokens
                    pricing[key] = float(value) * 1_000_000
                except (ValueError, TypeError):
                    pricing[key] = 0.0

        return ModelInfo(
            id=model_data.get("id", ""),
            name=model_data.get("name", model_data.get("id", "")),
            description=model_data.get("description", ""),
            context_length=model_data.get("context_length", 0),
            pricing=pricing,
            supported_parameters=model_data.get("supported_parameters", []),
            input_modalities=architecture.get("input_modalities", ["text"]),
            output_modalities=architecture.get("output_modalities", ["text"]),
        )

    def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
        """
        Fetch available models from OpenRouter.

        Args:
            filter_text_only: If True, exclude video-only models

        Returns:
            List of available models

        Raises:
            Exception: If API request fails
        """
        if self._models_cache is not None:
            return self._models_cache

        try:
            response = requests.get(
                f"{self.base_url}/models",
                headers=self._get_headers(),
                timeout=10,
            )
            response.raise_for_status()

            raw_models = response.json().get("data", [])
            self._raw_models_cache = raw_models

            models = []
            for model_data in raw_models:
                # Optionally filter out video-only models
                if filter_text_only:
                    modalities = model_data.get("modalities", [])
                    if modalities and "video" in modalities and "text" not in modalities:
                        continue

                models.append(self._parse_model(model_data))

            self._models_cache = models
            self.logger.info(f"Fetched {len(models)} models from OpenRouter")
            return models

        except requests.RequestException as e:
            self.logger.error(f"Failed to fetch models: {e}")
            raise

    def get_raw_models(self) -> List[Dict[str, Any]]:
        """
        Get raw model data as returned by the API.

        Useful for accessing provider-specific fields not in ModelInfo.

        Returns:
            List of raw model dictionaries
        """
        if self._raw_models_cache is None:
            self.list_models()
        return self._raw_models_cache or []
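`_parse_model` normalizes OpenRouter's per-token price strings into USD-per-million-token floats, falling back to zero on malformed values. The conversion in isolation (a sketch of that arithmetic, not the package's own function):

```python
# OpenRouter reports pricing as per-token strings; the provider converts them
# to USD per million tokens, treating unparseable values as free (0.0).
def to_per_million(value) -> float:
    try:
        return float(value) * 1_000_000
    except (ValueError, TypeError):
        return 0.0

print(to_per_million("0.000003"))  # a $0.000003/token price becomes ~$3 per million tokens
print(to_per_million(None))
print(to_per_million("n/a"))
```

Storing prices per million tokens keeps later cost estimates in the same units users see on pricing pages.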
    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        """
        Get information about a specific model.

        Args:
            model_id: The model identifier

        Returns:
            Model information or None if not found
        """
        models = self.list_models()
        for model in models:
            if model.id == model_id:
                return model
        return None

    def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get raw model data for a specific model.

        Args:
            model_id: The model identifier

        Returns:
            Raw model dictionary or None if not found
        """
        raw_models = self.get_raw_models()
        for model in raw_models:
            if model.get("id") == model_id:
                return model
        return None

    def _convert_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
        """
        Convert ChatMessage objects to API format.

        Args:
            messages: List of ChatMessage objects

        Returns:
            List of message dictionaries for the API
        """
        return [msg.to_dict() for msg in messages]
    def _parse_usage(self, usage_data: Any) -> Optional[UsageStats]:
        """
        Parse usage data from API response.

        Args:
            usage_data: Raw usage data from API

        Returns:
            Parsed UsageStats or None
        """
        if not usage_data:
            return None

        # Handle both attribute and dict access
        prompt_tokens = 0
        completion_tokens = 0
        total_cost = None

        if hasattr(usage_data, "prompt_tokens"):
            prompt_tokens = getattr(usage_data, "prompt_tokens", 0) or 0
        elif isinstance(usage_data, dict):
            prompt_tokens = usage_data.get("prompt_tokens", 0) or 0

        if hasattr(usage_data, "completion_tokens"):
            completion_tokens = getattr(usage_data, "completion_tokens", 0) or 0
        elif isinstance(usage_data, dict):
            completion_tokens = usage_data.get("completion_tokens", 0) or 0

        # Try alternative naming (input_tokens/output_tokens)
        if prompt_tokens == 0:
            if hasattr(usage_data, "input_tokens"):
                prompt_tokens = getattr(usage_data, "input_tokens", 0) or 0
            elif isinstance(usage_data, dict):
                prompt_tokens = usage_data.get("input_tokens", 0) or 0

        if completion_tokens == 0:
            if hasattr(usage_data, "output_tokens"):
                completion_tokens = getattr(usage_data, "output_tokens", 0) or 0
            elif isinstance(usage_data, dict):
                completion_tokens = usage_data.get("output_tokens", 0) or 0

        # Get cost if available.
        # OpenRouter returns cost in different places:
        # 1. As 'total_cost_usd' in the usage object (rare)
        # 2. As 'usage' at the root level (common - this is the dollar amount)
        if hasattr(usage_data, "total_cost_usd"):
            total_cost = getattr(usage_data, "total_cost_usd", None)
        elif hasattr(usage_data, "usage"):
            # OpenRouter puts cost as 'usage' field (dollar amount)
            total_cost = getattr(usage_data, "usage", None)
        elif isinstance(usage_data, dict):
            total_cost = usage_data.get("total_cost_usd") or usage_data.get("usage")

        return UsageStats(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            total_cost_usd=float(total_cost) if total_cost else None,
        )
    def _parse_tool_calls(self, tool_calls_data: Any) -> Optional[List[ToolCall]]:
        """
        Parse tool calls from API response.

        Args:
            tool_calls_data: Raw tool calls data

        Returns:
            List of ToolCall objects or None
        """
        if not tool_calls_data:
            return None

        tool_calls = []
        for tc in tool_calls_data:
            # Handle both attribute and dict access
            if hasattr(tc, "id"):
                tc_id = tc.id
                tc_type = getattr(tc, "type", "function")
                func = tc.function
                func_name = func.name
                func_args = func.arguments
            else:
                tc_id = tc.get("id", "")
                tc_type = tc.get("type", "function")
                func = tc.get("function", {})
                func_name = func.get("name", "")
                func_args = func.get("arguments", "{}")

            tool_calls.append(
                ToolCall(
                    id=tc_id,
                    type=tc_type,
                    function=ToolFunction(name=func_name, arguments=func_args),
                )
            )

        return tool_calls if tool_calls else None
    def _parse_response(self, response: Any) -> ChatResponse:
        """
        Parse API response into ChatResponse.

        Args:
            response: Raw API response

        Returns:
            Parsed ChatResponse
        """
        choices = []
        for choice in response.choices:
            msg = choice.message
            message = ChatMessage(
                role=msg.role if hasattr(msg, "role") else "assistant",
                content=msg.content if hasattr(msg, "content") else None,
                tool_calls=self._parse_tool_calls(getattr(msg, "tool_calls", None)),
            )
            choices.append(
                ChatResponseChoice(
                    index=choice.index if hasattr(choice, "index") else 0,
                    message=message,
                    finish_reason=getattr(choice, "finish_reason", None),
                )
            )

        return ChatResponse(
            id=response.id if hasattr(response, "id") else "",
            choices=choices,
            usage=self._parse_usage(getattr(response, "usage", None)),
            model=getattr(response, "model", None),
            created=getattr(response, "created", None),
        )
    def chat(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        transforms: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Union[ChatResponse, Iterator[StreamChunk]]:
        """
        Send a chat completion request to OpenRouter.

        Args:
            model: Model ID to use
            messages: List of chat messages
            stream: Whether to stream the response
            max_tokens: Maximum tokens in response
            temperature: Sampling temperature (0-2)
            tools: List of tool definitions for function calling
            tool_choice: How to handle tool selection ("auto", "none", etc.)
            transforms: List of transforms (e.g., ["middle-out"])
            **kwargs: Additional parameters

        Returns:
            ChatResponse for non-streaming, Iterator[StreamChunk] for streaming
        """
        # Build request parameters
        params: Dict[str, Any] = {
            "model": model,
            "messages": self._convert_messages(messages),
            "stream": stream,
            "http_headers": self._get_headers(),
        }

        # Request usage stats in streaming responses
        if stream:
            params["stream_options"] = {"include_usage": True}

        if max_tokens is not None:
            params["max_tokens"] = max_tokens

        if temperature is not None:
            params["temperature"] = temperature

        if tools:
            params["tools"] = tools
            params["tool_choice"] = tool_choice or "auto"

        if transforms:
            params["transforms"] = transforms

        # Add any additional parameters
        params.update(kwargs)

        self.logger.debug(f"Sending chat request to model {model}")

        try:
            response = self.client.chat.send(**params)

            if stream:
                return self._stream_response(response)
            else:
                return self._parse_response(response)

        except Exception as e:
            self.logger.error(f"Chat request failed: {e}")
            raise
    def _stream_response(self, response: Any) -> Iterator[StreamChunk]:
        """
        Process a streaming response.

        Args:
            response: Streaming response from API

        Yields:
            StreamChunk objects
        """
        last_usage = None

        try:
            for chunk in response:
                # Check for errors
                if hasattr(chunk, "error") and chunk.error:
                    yield StreamChunk(
                        id=getattr(chunk, "id", ""),
                        error=chunk.error.message if hasattr(chunk.error, "message") else str(chunk.error),
                    )
                    return

                # Extract delta content
                delta_content = None
                finish_reason = None

                if hasattr(chunk, "choices") and chunk.choices:
                    choice = chunk.choices[0]
                    if hasattr(choice, "delta"):
                        delta = choice.delta
                        if hasattr(delta, "content") and delta.content:
                            delta_content = delta.content
                    finish_reason = getattr(choice, "finish_reason", None)

                # Track usage from last chunk
                if hasattr(chunk, "usage") and chunk.usage:
                    last_usage = self._parse_usage(chunk.usage)

                yield StreamChunk(
                    id=getattr(chunk, "id", ""),
                    delta_content=delta_content,
                    finish_reason=finish_reason,
                    usage=last_usage if finish_reason else None,
                )

        except Exception as e:
            self.logger.error(f"Stream error: {e}")
            yield StreamChunk(id="", error=str(e))
    async def chat_async(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
        """
        Send an async chat completion request.

        Note: Currently wraps the sync implementation.
        TODO: Implement true async support when OpenRouter SDK supports it.

        Args:
            model: Model ID to use
            messages: List of chat messages
            stream: Whether to stream the response
            max_tokens: Maximum tokens in response
            temperature: Sampling temperature
            tools: List of tool definitions
            tool_choice: Tool selection mode
            **kwargs: Additional parameters

        Returns:
            ChatResponse for non-streaming, AsyncIterator for streaming
        """
        # For now, use sync implementation
        # TODO: Add true async when SDK supports it
        result = self.chat(
            model=model,
            messages=messages,
            stream=stream,
            max_tokens=max_tokens,
            temperature=temperature,
            tools=tools,
            tool_choice=tool_choice,
            **kwargs,
        )

        if stream and isinstance(result, Iterator):
            # Convert sync iterator to async
            async def async_iter() -> AsyncIterator[StreamChunk]:
                for chunk in result:
                    yield chunk

            return async_iter()

        return result
    def get_credits(self) -> Optional[Dict[str, Any]]:
        """
        Get OpenRouter account credit information.

        Returns:
            Dict with credit info:
            - total_credits: Total credits purchased
            - used_credits: Credits used
            - credits_left: Remaining credits

        Raises:
            Exception: If API request fails
        """
        if not self.api_key:
            return None

        try:
            response = requests.get(
                f"{self.base_url}/credits",
                headers=self._get_headers(),
                timeout=10,
            )
            response.raise_for_status()

            data = response.json().get("data", {})
            total_credits = float(data.get("total_credits", 0))
            total_usage = float(data.get("total_usage", 0))
            credits_left = total_credits - total_usage

            return {
                "total_credits": total_credits,
                "used_credits": total_usage,
                "credits_left": credits_left,
                "total_credits_formatted": f"${total_credits:.2f}",
                "used_credits_formatted": f"${total_usage:.2f}",
                "credits_left_formatted": f"${credits_left:.2f}",
            }

        except Exception as e:
            self.logger.error(f"Failed to fetch credits: {e}")
            return None
    def clear_cache(self) -> None:
        """Clear the models cache to force a refresh."""
        self._models_cache = None
        self._raw_models_cache = None
        self.logger.debug("Models cache cleared")

    def get_effective_model_id(self, model_id: str, online_enabled: bool) -> str:
        """
        Get the effective model ID with online suffix if needed.

        Args:
            model_id: Base model ID
            online_enabled: Whether online mode is enabled

        Returns:
            Model ID with :online suffix if applicable
        """
        if online_enabled and not model_id.endswith(":online"):
            return f"{model_id}:online"
        return model_id
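The `:online` suffix logic is a small pure function, so it is easy to check in isolation; a standalone sketch of the same rule (the model IDs are examples only):

```python
# Same rule as get_effective_model_id above: append ":online" to enable
# OpenRouter web search, but never append it twice.
def effective_model_id(model_id: str, online_enabled: bool) -> str:
    if online_enabled and not model_id.endswith(":online"):
        return f"{model_id}:online"
    return model_id

print(effective_model_id("openai/gpt-4o", True))
print(effective_model_id("openai/gpt-4o:online", True))
print(effective_model_id("openai/gpt-4o", False))
```

Guarding on `endswith(":online")` makes the function idempotent, so callers can apply it on every request without tracking whether the suffix was already added.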
    def estimate_cost(
        self,
        model_id: str,
        input_tokens: int,
        output_tokens: int,
    ) -> float:
        """
        Estimate the cost for a completion.

        Args:
            model_id: Model ID
            input_tokens: Number of input tokens
            output_tokens: Number of output tokens

        Returns:
            Estimated cost in USD
        """
        model = self.get_model(model_id)
        if model and model.pricing:
            input_cost = model.pricing.get("prompt", 0) * input_tokens / 1_000_000
            output_cost = model.pricing.get("completion", 0) * output_tokens / 1_000_000
            return input_cost + output_cost

        # Fallback to default pricing if model not found
        from oai.constants import MODEL_PRICING

        input_cost = MODEL_PRICING["input"] * input_tokens / 1_000_000
        output_cost = MODEL_PRICING["output"] * output_tokens / 1_000_000
        return input_cost + output_cost
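Since `_parse_model` stores pricing per million tokens, the cost estimate is just `rate * tokens / 1_000_000` per side. The arithmetic in isolation (a sketch with hypothetical $3/M input and $15/M output rates, not actual OpenRouter prices):

```python
# Same arithmetic as estimate_cost above: pricing dict holds USD per
# million tokens under "prompt" (input) and "completion" (output).
def estimate_cost(pricing: dict, input_tokens: int, output_tokens: int) -> float:
    input_cost = pricing.get("prompt", 0) * input_tokens / 1_000_000
    output_cost = pricing.get("completion", 0) * output_tokens / 1_000_000
    return input_cost + output_cost

# Hypothetical rates: $3/M input, $15/M output.
print(estimate_cost({"prompt": 3.0, "completion": 15.0}, 1_000, 500))
```

With 1,000 input and 500 output tokens at those rates the estimate is $0.003 + $0.0075, roughly a penny; a missing pricing dict degrades to a zero estimate.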
2  oai/py.typed  Normal file
@@ -0,0 +1,2 @@
# Marker file for PEP 561
# This package supports type checking
5  oai/tui/__init__.py  Normal file
@@ -0,0 +1,5 @@
"""Textual TUI interface for oAI."""

from oai.tui.app import oAIChatApp

__all__ = ["oAIChatApp"]
1055  oai/tui/app.py  Normal file
File diff suppressed because it is too large.
21  oai/tui/screens/__init__.py  Normal file
@@ -0,0 +1,21 @@
"""TUI screens for oAI."""

from oai.tui.screens.config_screen import ConfigScreen
from oai.tui.screens.conversation_selector import ConversationSelectorScreen
from oai.tui.screens.credits_screen import CreditsScreen
from oai.tui.screens.dialogs import AlertDialog, ConfirmDialog, InputDialog
from oai.tui.screens.help_screen import HelpScreen
from oai.tui.screens.model_selector import ModelSelectorScreen
from oai.tui.screens.stats_screen import StatsScreen

__all__ = [
    "AlertDialog",
    "ConfirmDialog",
    "ConfigScreen",
    "ConversationSelectorScreen",
    "CreditsScreen",
    "InputDialog",
    "HelpScreen",
    "ModelSelectorScreen",
    "StatsScreen",
]
107  oai/tui/screens/config_screen.py  Normal file
@@ -0,0 +1,107 @@
"""Configuration screen for oAI TUI."""

from textual.app import ComposeResult
from textual.containers import Container, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, Static

from oai.config.settings import Settings


class ConfigScreen(ModalScreen[None]):
    """Modal screen displaying configuration settings."""

    DEFAULT_CSS = """
    ConfigScreen {
        align: center middle;
    }

    ConfigScreen > Container {
        width: 70;
        height: auto;
        background: #1e1e1e;
        border: solid #555555;
    }

    ConfigScreen .header {
        dock: top;
        width: 100%;
        height: auto;
        background: #2d2d2d;
        color: #cccccc;
        padding: 0 2;
    }

    ConfigScreen .content {
        width: 100%;
        height: auto;
        background: #1e1e1e;
        padding: 2;
        color: #cccccc;
    }

    ConfigScreen .footer {
        dock: bottom;
        width: 100%;
        height: auto;
        background: #2d2d2d;
        padding: 1 2;
        align: center middle;
    }
    """

    def __init__(self, settings: Settings):
        super().__init__()
        self.settings = settings

    def compose(self) -> ComposeResult:
        """Compose the screen."""
        with Container():
            yield Static("[bold]Configuration[/]", classes="header")
            with Vertical(classes="content"):
                yield Static(self._get_config_text(), markup=True)
            with Vertical(classes="footer"):
                yield Button("Close", id="close", variant="primary")

    def _get_config_text(self) -> str:
        """Generate the configuration text."""
        from oai.constants import DEFAULT_SYSTEM_PROMPT

        # API key display
        api_key_display = "***" + self.settings.api_key[-4:] if self.settings.api_key else "Not set"

        # System prompt display
        if self.settings.default_system_prompt is None:
            system_prompt_display = f"[default] {DEFAULT_SYSTEM_PROMPT[:40]}..."
        elif self.settings.default_system_prompt == "":
            system_prompt_display = "[blank]"
        else:
            prompt = self.settings.default_system_prompt
            system_prompt_display = prompt[:50] + "..." if len(prompt) > 50 else prompt

        return f"""
[bold cyan]═══ CONFIGURATION ═══[/]

[bold]API Key:[/] {api_key_display}
[bold]Base URL:[/] {self.settings.base_url}
[bold]Default Model:[/] {self.settings.default_model or "Not set"}

[bold]System Prompt:[/] {system_prompt_display}

[bold]Streaming:[/] {"on" if self.settings.stream_enabled else "off"}
[bold]Cost Warning:[/] ${self.settings.cost_warning_threshold:.4f}
[bold]Max Tokens:[/] {self.settings.max_tokens}
[bold]Default Online:[/] {"on" if self.settings.default_online_mode else "off"}
[bold]Log Level:[/] {self.settings.log_level}

[dim]Use /config [setting] [value] to modify settings[/]
"""

    def on_button_pressed(self, event: Button.Pressed) -> None:
        """Handle button press."""
        self.dismiss()

    def on_key(self, event) -> None:
        """Handle keyboard shortcuts."""
        if event.key in ("escape", "enter"):
            self.dismiss()
205  oai/tui/screens/conversation_selector.py  Normal file
@@ -0,0 +1,205 @@
"""Conversation selector screen for oAI TUI."""

from typing import List, Optional

from textual.app import ComposeResult
from textual.containers import Container, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, DataTable, Input, Static


class ConversationSelectorScreen(ModalScreen[Optional[dict]]):
    """Modal screen for selecting a saved conversation."""

    DEFAULT_CSS = """
    ConversationSelectorScreen {
        align: center middle;
    }

    ConversationSelectorScreen > Container {
        width: 80%;
        height: 70%;
        background: #1e1e1e;
        border: solid #555555;
        layout: vertical;
    }

    ConversationSelectorScreen .header {
        height: 3;
        width: 100%;
        background: #2d2d2d;
        color: #cccccc;
        padding: 0 2;
        content-align: center middle;
    }

    ConversationSelectorScreen .search-input {
        height: 3;
        width: 100%;
        background: #2a2a2a;
        border: solid #555555;
        margin: 0 0 1 0;
    }

    ConversationSelectorScreen .search-input:focus {
        border: solid #888888;
    }

    ConversationSelectorScreen DataTable {
        height: 1fr;
        width: 100%;
        background: #1e1e1e;
        border: solid #555555;
    }

    ConversationSelectorScreen .footer {
        height: 5;
        width: 100%;
        background: #2d2d2d;
        padding: 1 2;
        align: center middle;
    }

    ConversationSelectorScreen Button {
        margin: 0 1;
    }
    """

    def __init__(self, conversations: List[dict]):
        super().__init__()
        self.all_conversations = conversations
        self.filtered_conversations = conversations
        self.selected_conversation: Optional[dict] = None

    def compose(self) -> ComposeResult:
        """Compose the screen."""
        with Container():
            yield Static(
                f"[bold]Load Conversation[/] [dim]({len(self.all_conversations)} saved)[/]",
                classes="header"
            )
            yield Input(placeholder="Search conversations...", id="search-input", classes="search-input")
            yield DataTable(id="conv-table", cursor_type="row", show_header=True, zebra_stripes=True)
            with Vertical(classes="footer"):
                yield Button("Load", id="load", variant="success")
                yield Button("Cancel", id="cancel", variant="error")

    def on_mount(self) -> None:
        """Initialize the table when mounted."""
        table = self.query_one("#conv-table", DataTable)

        # Add columns
        table.add_column("#", width=5)
        table.add_column("Name", width=40)
        table.add_column("Messages", width=12)
        table.add_column("Last Saved", width=20)

        # Populate table
        self._populate_table()

        # Focus the table if the list is small (fits on screen), otherwise focus search
        if len(self.all_conversations) <= 10:
            table.focus()
        else:
            search_input = self.query_one("#search-input", Input)
            search_input.focus()

    def _populate_table(self) -> None:
        """Populate the table with conversations."""
        table = self.query_one("#conv-table", DataTable)
        table.clear()

        for idx, conv in enumerate(self.filtered_conversations, 1):
            name = conv.get("name", "Unknown")
            message_count = str(conv.get("message_count", 0))
            last_saved = conv.get("last_saved", "Unknown")

            # Format timestamp if it's a full datetime
            if "T" in last_saved or len(last_saved) > 20:
                try:
                    # Truncate to just date and time
                    last_saved = last_saved[:19].replace("T", " ")
                except Exception:
                    pass

            table.add_row(
                str(idx),
                name,
                message_count,
                last_saved,
                key=str(idx)
            )

    def on_input_changed(self, event: Input.Changed) -> None:
        """Filter conversations based on search input."""
        if event.input.id != "search-input":
            return

        search_term = event.value.lower()

        if not search_term:
            self.filtered_conversations = self.all_conversations
        else:
            self.filtered_conversations = [
                c for c in self.all_conversations
                if search_term in c.get("name", "").lower()
            ]

        self._populate_table()

    def on_data_table_row_selected(self, event: DataTable.RowSelected) -> None:
        """Handle row selection (click)."""
        try:
            row_index = int(event.row_key.value) - 1
            if 0 <= row_index < len(self.filtered_conversations):
                self.selected_conversation = self.filtered_conversations[row_index]
        except (ValueError, IndexError):
            pass

    def on_data_table_row_highlighted(self, event) -> None:
        """Handle row highlight (arrow-key navigation)."""
        try:
            table = self.query_one("#conv-table", DataTable)
            if table.cursor_row is not None:
                row_data = table.get_row_at(table.cursor_row)
                if row_data:
                    row_index = int(row_data[0]) - 1
                    if 0 <= row_index < len(self.filtered_conversations):
                        self.selected_conversation = self.filtered_conversations[row_index]
        except (ValueError, IndexError, AttributeError):
            pass

    def on_button_pressed(self, event: Button.Pressed) -> None:
        """Handle button press."""
        if event.button.id == "load":
            if self.selected_conversation:
                self.dismiss(self.selected_conversation)
            else:
                self.dismiss(None)
        else:
            self.dismiss(None)

    def on_key(self, event) -> None:
        """Handle keyboard shortcuts."""
        if event.key == "escape":
            self.dismiss(None)
        elif event.key == "enter":
            # If in the search input, move focus to the table
            search_input = self.query_one("#search-input", Input)
            if search_input.has_focus:
                table = self.query_one("#conv-table", DataTable)
                table.focus()
            # If in the table, select the current row
            else:
                table = self.query_one("#conv-table", DataTable)
                if table.cursor_row is not None:
                    try:
                        row_data = table.get_row_at(table.cursor_row)
                        if row_data:
                            row_index = int(row_data[0]) - 1
                            if 0 <= row_index < len(self.filtered_conversations):
                                selected = self.filtered_conversations[row_index]
                                self.dismiss(selected)
                    except (ValueError, IndexError, AttributeError):
                        if self.selected_conversation:
                            self.dismiss(self.selected_conversation)
oai/tui/screens/credits_screen.py
Normal file
125
oai/tui/screens/credits_screen.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""Credits screen for oAI TUI."""
|
||||||
|
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
from textual.app import ComposeResult
|
||||||
|
from textual.containers import Container, Vertical
|
||||||
|
from textual.screen import ModalScreen
|
||||||
|
from textual.widgets import Button, Static
|
||||||
|
|
||||||
|
from oai.core.client import AIClient
|
||||||
|
|
||||||
|
|
||||||
|
class CreditsScreen(ModalScreen[None]):
|
||||||
|
"""Modal screen displaying account credits."""
|
||||||
|
|
||||||
|
DEFAULT_CSS = """
|
||||||
|
CreditsScreen {
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
CreditsScreen > Container {
|
||||||
|
width: 60;
|
||||||
|
height: auto;
|
||||||
|
background: #1e1e1e;
|
||||||
|
border: solid #555555;
|
||||||
|
}
|
||||||
|
|
||||||
|
CreditsScreen .header {
|
||||||
|
dock: top;
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
background: #2d2d2d;
|
||||||
|
color: #cccccc;
|
||||||
|
padding: 0 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
CreditsScreen .content {
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
background: #1e1e1e;
|
||||||
|
padding: 2;
|
||||||
|
color: #cccccc;
|
||||||
|
}
|
||||||
|
|
||||||
|
CreditsScreen .footer {
|
||||||
|
dock: bottom;
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
background: #2d2d2d;
|
||||||
|
padding: 1 2;
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, client: AIClient):
|
||||||
|
super().__init__()
|
||||||
|
self.client = client
|
||||||
|
self.credits_data: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
"""Compose the screen."""
|
||||||
|
with Container():
|
||||||
|
yield Static("[bold]Account Credits[/]", classes="header")
|
||||||
|
with Vertical(classes="content"):
|
||||||
|
yield Static("[dim]Loading...[/]", id="credits-content", markup=True)
|
||||||
|
with Vertical(classes="footer"):
|
||||||
|
yield Button("Close", id="close", variant="primary")
|
||||||
|
|
||||||
|
def on_mount(self) -> None:
|
||||||
|
"""Fetch credits when mounted."""
|
||||||
|
self.fetch_credits()
|
||||||
|
|
||||||
|
def fetch_credits(self) -> None:
|
||||||
|
"""Fetch and display credits information."""
|
||||||
|
try:
|
||||||
|
self.credits_data = self.client.provider.get_credits()
|
||||||
|
content = self.query_one("#credits-content", Static)
|
||||||
|
content.update(self._get_credits_text())
|
||||||
|
except Exception as e:
|
||||||
|
content = self.query_one("#credits-content", Static)
|
||||||
|
content.update(f"[red]Error fetching credits:[/]\n{str(e)}")
|
||||||
|
|
||||||
|
def _get_credits_text(self) -> str:
|
||||||
|
"""Generate the credits text."""
|
||||||
|
if not self.credits_data:
|
||||||
|
return "[yellow]No credit information available[/]"
|
||||||
|
|
||||||
|
total = self.credits_data.get("total_credits", 0)
|
||||||
|
used = self.credits_data.get("used_credits", 0)
|
||||||
|
remaining = self.credits_data.get("credits_left", 0)
|
||||||
|
|
||||||
|
# Calculate percentage used
|
||||||
|
if total > 0:
|
||||||
|
percent_used = (used / total) * 100
|
||||||
|
percent_remaining = (remaining / total) * 100
|
||||||
|
else:
|
||||||
|
percent_used = 0
|
||||||
|
percent_remaining = 0
|
||||||
|
|
||||||
|
# Color code based on remaining credits
|
||||||
|
if percent_remaining > 50:
|
||||||
|
remaining_color = "green"
|
||||||
|
elif percent_remaining > 20:
|
||||||
|
remaining_color = "yellow"
|
||||||
|
else:
|
||||||
|
remaining_color = "red"
|
||||||
|
|
||||||
|
return f"""
|
||||||
|
[bold cyan]═══ OPENROUTER CREDITS ═══[/]
|
||||||
|
|
||||||
|
[bold]Total Credits:[/] ${total:.2f}
|
||||||
|
[bold]Used:[/] ${used:.2f} [dim]({percent_used:.1f}%)[/]
|
||||||
|
[bold]Remaining:[/] [{remaining_color}]${remaining:.2f}[/] [dim]({percent_remaining:.1f}%)[/]
|
||||||
|
|
||||||
|
[dim]Visit openrouter.ai to add more credits[/]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
|
"""Handle button press."""
|
||||||
|
self.dismiss()
|
||||||
|
|
||||||
|
def on_key(self, event) -> None:
|
||||||
|
"""Handle keyboard shortcuts."""
|
||||||
|
if event.key in ("escape", "enter"):
|
||||||
|
self.dismiss()
|
||||||
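The color thresholds used by `CreditsScreen._get_credits_text` can be expressed as a small pure function. A standalone sketch (the function name is illustrative; it only reproduces the >50% green / >20% yellow / else red logic, including the zero-total guard):

```python
def remaining_color(total: float, remaining: float) -> str:
    # Guard against division by zero when no credit data is available.
    percent_remaining = (remaining / total) * 100 if total > 0 else 0
    if percent_remaining > 50:
        return "green"
    if percent_remaining > 20:
        return "yellow"
    return "red"
```

Note that a zero total yields 0%, which falls through to "red"; the boundary values 50% and 20% are exclusive, so exactly half remaining shows yellow.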
236  oai/tui/screens/dialogs.py  Normal file
@@ -0,0 +1,236 @@
"""Modal dialog screens for oAI TUI."""

from typing import Callable, Optional

from textual.app import ComposeResult
from textual.containers import Container, Horizontal, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, Input, Label, Static


class ConfirmDialog(ModalScreen[bool]):
    """A confirmation dialog with Yes/No buttons."""

    DEFAULT_CSS = """
    ConfirmDialog {
        align: center middle;
    }

    ConfirmDialog > Container {
        width: 60;
        height: auto;
        background: #2d2d2d;
        border: solid #555555;
        padding: 2;
    }

    ConfirmDialog Label {
        width: 100%;
        content-align: center middle;
        margin-bottom: 2;
        color: #cccccc;
    }

    ConfirmDialog Horizontal {
        width: 100%;
        height: auto;
        align: center middle;
    }

    ConfirmDialog Button {
        margin: 0 1;
    }
    """

    def __init__(
        self,
        message: str,
        title: str = "Confirm",
        yes_label: str = "Yes",
        no_label: str = "No",
    ):
        super().__init__()
        self.message = message
        self.title = title
        self.yes_label = yes_label
        self.no_label = no_label

    def compose(self) -> ComposeResult:
        """Compose the dialog."""
        with Container():
            yield Static(f"[bold]{self.title}[/]", classes="dialog-title")
            yield Label(self.message)
            with Horizontal():
                yield Button(self.yes_label, id="yes", variant="success")
                yield Button(self.no_label, id="no", variant="error")

    def on_button_pressed(self, event: Button.Pressed) -> None:
        """Handle button press."""
        if event.button.id == "yes":
            self.dismiss(True)
        else:
            self.dismiss(False)

    def on_key(self, event) -> None:
        """Handle keyboard shortcuts."""
        if event.key == "escape":
            self.dismiss(False)
        elif event.key == "enter":
            self.dismiss(True)


class InputDialog(ModalScreen[Optional[str]]):
    """An input dialog for text entry."""

    DEFAULT_CSS = """
    InputDialog {
        align: center middle;
    }

    InputDialog > Container {
        width: 70;
        height: auto;
        background: #2d2d2d;
        border: solid #555555;
        padding: 2;
    }

    InputDialog Label {
        width: 100%;
        margin-bottom: 1;
        color: #cccccc;
    }

    InputDialog Input {
        width: 100%;
        margin-bottom: 2;
        background: #3a3a3a;
        border: solid #555555;
    }

    InputDialog Input:focus {
        border: solid #888888;
    }

    InputDialog Horizontal {
        width: 100%;
        height: auto;
        align: center middle;
    }

    InputDialog Button {
        margin: 0 1;
    }
    """

    def __init__(
        self,
        message: str,
        title: str = "Input",
        default: str = "",
        placeholder: str = "",
    ):
        super().__init__()
        self.message = message
        self.title = title
        self.default = default
        self.placeholder = placeholder

    def compose(self) -> ComposeResult:
        """Compose the dialog."""
        with Container():
            yield Static(f"[bold]{self.title}[/]", classes="dialog-title")
            yield Label(self.message)
            yield Input(
                value=self.default,
                placeholder=self.placeholder,
                id="input-field"
            )
            with Horizontal():
                yield Button("OK", id="ok", variant="primary")
                yield Button("Cancel", id="cancel")

    def on_mount(self) -> None:
        """Focus the input field when mounted."""
        input_field = self.query_one("#input-field", Input)
        input_field.focus()

    def on_button_pressed(self, event: Button.Pressed) -> None:
        """Handle button press."""
        if event.button.id == "ok":
            input_field = self.query_one("#input-field", Input)
            self.dismiss(input_field.value)
        else:
            self.dismiss(None)

    def on_input_submitted(self, event: Input.Submitted) -> None:
        """Handle Enter key in the input field."""
        self.dismiss(event.value)

    def on_key(self, event) -> None:
        """Handle keyboard shortcuts."""
        if event.key == "escape":
            self.dismiss(None)


class AlertDialog(ModalScreen[None]):
    """A simple alert/message dialog."""

    DEFAULT_CSS = """
    AlertDialog {
        align: center middle;
    }

    AlertDialog > Container {
        width: 60;
        height: auto;
        background: #2d2d2d;
        border: solid #555555;
        padding: 2;
    }

    AlertDialog Label {
        width: 100%;
        content-align: center middle;
        margin-bottom: 2;
        color: #cccccc;
    }

    AlertDialog Horizontal {
        width: 100%;
        height: auto;
        align: center middle;
    }
    """

    def __init__(self, message: str, title: str = "Alert", variant: str = "default"):
        super().__init__()
        self.message = message
        self.title = title
        self.variant = variant

    def compose(self) -> ComposeResult:
        """Compose the dialog."""
        # Choose color based on variant (using the design system)
        color = "$primary"
        if self.variant == "error":
            color = "$error"
        elif self.variant == "success":
            color = "$success"
        elif self.variant == "warning":
            color = "$warning"

        with Container():
            yield Static(f"[bold {color}]{self.title}[/]", classes="dialog-title")
            yield Label(self.message)
            with Horizontal():
                yield Button("OK", id="ok", variant="primary")

    def on_button_pressed(self, event: Button.Pressed) -> None:
        """Handle button press."""
        self.dismiss()

    def on_key(self, event) -> None:
        """Handle keyboard shortcuts."""
        if event.key in ("escape", "enter"):
            self.dismiss()
140  oai/tui/screens/help_screen.py  Normal file
@@ -0,0 +1,140 @@
"""Help screen for oAI TUI."""

from textual.app import ComposeResult
from textual.containers import Container, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, Static


class HelpScreen(ModalScreen[None]):
    """Modal screen displaying help and commands."""

    DEFAULT_CSS = """
    HelpScreen {
        align: center middle;
    }

    HelpScreen > Container {
        width: 90%;
        height: 85%;
        background: #1e1e1e;
        border: solid #555555;
    }

    HelpScreen .header {
        dock: top;
        width: 100%;
        height: auto;
        background: #2d2d2d;
        color: #cccccc;
        padding: 0 2;
    }

    HelpScreen .content {
        height: 1fr;
        background: #1e1e1e;
        padding: 2;
        overflow-y: auto;
        color: #cccccc;
    }

    HelpScreen .footer {
        dock: bottom;
        width: 100%;
        height: auto;
        background: #2d2d2d;
        padding: 1 2;
        align: center middle;
    }
    """

    def compose(self) -> ComposeResult:
        """Compose the screen."""
        with Container():
            yield Static("[bold]oAI Help & Commands[/]", classes="header")
            with Vertical(classes="content"):
                yield Static(self._get_help_text(), markup=True)
            with Vertical(classes="footer"):
                yield Button("Close", id="close", variant="primary")

    def _get_help_text(self) -> str:
        """Generate the help text."""
        return """
[bold cyan]═══ KEYBOARD SHORTCUTS ═══[/]
[bold]F1[/]        Show this help (Ctrl+H may not work)
[bold]F2[/]        Open model selector (Ctrl+M may not work)
[bold]F3[/]        Copy last AI response to clipboard
[bold]Ctrl+S[/]    Show session statistics
[bold]Ctrl+L[/]    Clear chat display
[bold]Ctrl+P[/]    Show previous message
[bold]Ctrl+N[/]    Show next message
[bold]Ctrl+Y[/]    Copy last AI response (alternative to F3)
[bold]Ctrl+Q[/]    Quit application
[bold]Up/Down[/]   Navigate input history
[bold]ESC[/]       Close dialogs
[dim]Note: Some Ctrl keys may be captured by your terminal[/]

[bold cyan]═══ SLASH COMMANDS ═══[/]
[bold yellow]Session Control:[/]
  /reset               Clear conversation history (with confirmation)
  /clear               Clear the chat display
  /memory on/off       Toggle conversation memory
  /online on/off       Toggle online search mode
  /exit, /quit, /bye   Exit the application

[bold yellow]Model & Configuration:[/]
  /model [search]      Open model selector with optional search
  /config              View configuration settings
  /config api          Set API key (prompts for input)
  /config stream on    Enable streaming responses
  /system [prompt]     Set session system prompt
  /maxtoken [n]        Set session token limit

[bold yellow]Conversation Management:[/]
  /save [name]         Save current conversation
  /load [name]         Load saved conversation (shows picker if no name)
  /list                List all saved conversations
  /delete <name>       Delete a saved conversation

[bold yellow]Export:[/]
  /export md [file]    Export as Markdown
  /export json [file]  Export as JSON
  /export html [file]  Export as HTML

[bold yellow]History Navigation:[/]
  /prev                Show previous message in history
  /next                Show next message in history

[bold yellow]MCP (Model Context Protocol):[/]
  /mcp on              Enable MCP file access
  /mcp off             Disable MCP
  /mcp status          Show MCP status
  /mcp add <path>      Add folder for file access
  /mcp list            List registered folders
  /mcp write           Toggle write permissions

[bold yellow]Information & Utilities:[/]
  /help                Show this help screen
  /stats               Show session statistics
  /credits             Check account credits
  /retry               Retry last prompt
  /paste               Paste from clipboard and send

[bold cyan]═══ TIPS ═══[/]
• Type [bold]/[/] to see command suggestions with [bold]Tab[/] to autocomplete
• Use [bold]Up/Down arrows[/] to navigate your input history
• Type [bold]//[/] at start to escape commands (sends /help as a literal message)
• All messages support [bold]Markdown formatting[/] with syntax highlighting
• Responses stream in real-time for better interactivity
• Enable MCP to let AI access your local files and databases
• Use [bold]F1[/] or [bold]F2[/] if Ctrl shortcuts don't work in your terminal
"""

    def on_button_pressed(self, event: Button.Pressed) -> None:
        """Handle button press."""
        self.dismiss()

    def on_key(self, event) -> None:
        """Handle keyboard shortcuts."""
        if event.key in ("escape", "enter"):
            self.dismiss()
254
oai/tui/screens/model_selector.py
Normal file
254
oai/tui/screens/model_selector.py
Normal file
@@ -0,0 +1,254 @@
"""Model selector screen for oAI TUI."""

from typing import List, Optional

from textual.app import ComposeResult
from textual.containers import Container, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, DataTable, Input, Label, Static


class ModelSelectorScreen(ModalScreen[Optional[dict]]):
    """Modal screen for selecting an AI model."""

    DEFAULT_CSS = """
    ModelSelectorScreen {
        align: center middle;
    }

    ModelSelectorScreen > Container {
        width: 90%;
        height: 85%;
        background: #1e1e1e;
        border: solid #555555;
        layout: vertical;
    }

    ModelSelectorScreen .header {
        height: 3;
        width: 100%;
        background: #2d2d2d;
        color: #cccccc;
        padding: 0 2;
        content-align: center middle;
    }

    ModelSelectorScreen .search-input {
        height: 3;
        width: 100%;
        background: #2a2a2a;
        border: solid #555555;
        margin: 0 0 1 0;
    }

    ModelSelectorScreen .search-input:focus {
        border: solid #888888;
    }

    ModelSelectorScreen DataTable {
        height: 1fr;
        width: 100%;
        background: #1e1e1e;
        border: solid #555555;
    }

    ModelSelectorScreen .footer {
        height: 5;
        width: 100%;
        background: #2d2d2d;
        padding: 1 2;
        align: center middle;
    }

    ModelSelectorScreen Button {
        margin: 0 1;
    }
    """

    def __init__(self, models: List[dict], current_model: Optional[str] = None):
        super().__init__()
        self.all_models = models
        self.filtered_models = models
        self.current_model = current_model
        self.selected_model: Optional[dict] = None

    def compose(self) -> ComposeResult:
        """Compose the screen."""
        with Container():
            yield Static(
                f"[bold]Select Model[/] [dim]({len(self.all_models)} available)[/]",
                classes="header"
            )
            yield Input(placeholder="Search to filter models...", id="search-input", classes="search-input")
            yield DataTable(id="model-table", cursor_type="row", show_header=True, zebra_stripes=True)
            with Vertical(classes="footer"):
                yield Button("Select", id="select", variant="success")
                yield Button("Cancel", id="cancel", variant="error")

    def on_mount(self) -> None:
        """Initialize the table when mounted."""
        table = self.query_one("#model-table", DataTable)

        # Add columns
        table.add_column("#", width=5)
        table.add_column("Model ID", width=35)
        table.add_column("Name", width=30)
        table.add_column("Context", width=10)
        table.add_column("Price", width=12)
        table.add_column("Img", width=4)
        table.add_column("Tools", width=6)
        table.add_column("Online", width=7)

        # Populate table
        self._populate_table()

        # Focus table if the list is small (fits on screen), otherwise focus search
        if len(self.filtered_models) <= 20:
            table.focus()
        else:
            search_input = self.query_one("#search-input", Input)
            search_input.focus()

    def _populate_table(self) -> None:
        """Populate the table with models."""
        table = self.query_one("#model-table", DataTable)
        table.clear()

        rows_added = 0
        for idx, model in enumerate(self.filtered_models, 1):
            try:
                model_id = model.get("id", "")
                name = model.get("name", "")
                context = str(model.get("context_length", "N/A"))

                # Format pricing
                pricing = model.get("pricing", {})
                prompt_price = pricing.get("prompt", "0")
                completion_price = pricing.get("completion", "0")

                # Convert to numbers and format
                try:
                    prompt = float(prompt_price) * 1_000_000  # Convert to per 1M tokens
                    completion = float(completion_price) * 1_000_000
                    if prompt == 0 and completion == 0:
                        price = "Free"
                    else:
                        price = f"${prompt:.2f}/${completion:.2f}"
                except (TypeError, ValueError):
                    price = "N/A"

                # Check capabilities
                architecture = model.get("architecture", {})
                modality = architecture.get("modality", "")
                supported_params = model.get("supported_parameters", [])

                # Vision support: check if modality contains "image"
                supports_vision = "image" in modality

                # Tool support: check if "tools" or "tool_choice" in supported_parameters
                supports_tools = "tools" in supported_params or "tool_choice" in supported_params

                # Online support: check if model can use the :online suffix (most models can)
                # Models that already have :online in their ID support it
                supports_online = ":online" in model_id or model_id not in ["openrouter/free"]

                # Format capability indicators
                img_indicator = "✓" if supports_vision else "-"
                tools_indicator = "✓" if supports_tools else "-"
                web_indicator = "✓" if supports_online else "-"

                # Add row
                table.add_row(
                    str(idx),
                    model_id,
                    name,
                    context,
                    price,
                    img_indicator,
                    tools_indicator,
                    web_indicator,
                    key=str(idx)
                )
                rows_added += 1

            except Exception:
                # Silently skip rows that fail to add
                pass

    def on_input_changed(self, event: Input.Changed) -> None:
        """Filter models based on search input."""
        if event.input.id != "search-input":
            return

        search_term = event.value.lower()

        if not search_term:
            self.filtered_models = self.all_models
        else:
            self.filtered_models = [
                m for m in self.all_models
                if search_term in m.get("id", "").lower()
                or search_term in m.get("name", "").lower()
            ]

        self._populate_table()

    def on_data_table_row_selected(self, event: DataTable.RowSelected) -> None:
        """Handle row selection (click or arrow navigation)."""
        try:
            row_index = int(event.row_key.value) - 1
            if 0 <= row_index < len(self.filtered_models):
                self.selected_model = self.filtered_models[row_index]
        except (ValueError, IndexError):
            pass

    def on_data_table_row_highlighted(self, event) -> None:
        """Handle row highlight (arrow key navigation)."""
        try:
            table = self.query_one("#model-table", DataTable)
            if table.cursor_row is not None:
                row_data = table.get_row_at(table.cursor_row)
                if row_data:
                    row_index = int(row_data[0]) - 1
                    if 0 <= row_index < len(self.filtered_models):
                        self.selected_model = self.filtered_models[row_index]
        except (ValueError, IndexError, AttributeError):
            pass

    def on_button_pressed(self, event: Button.Pressed) -> None:
        """Handle button press."""
        if event.button.id == "select":
            if self.selected_model:
                self.dismiss(self.selected_model)
            else:
                # No selection, dismiss without result
                self.dismiss(None)
        else:
            self.dismiss(None)

    def on_key(self, event) -> None:
        """Handle keyboard shortcuts."""
        if event.key == "escape":
            self.dismiss(None)
        elif event.key == "enter":
            # If in search input, move to table
            search_input = self.query_one("#search-input", Input)
            if search_input.has_focus:
                table = self.query_one("#model-table", DataTable)
                table.focus()
            # If in table or anywhere else, select the current row
            else:
                table = self.query_one("#model-table", DataTable)
                # Get the currently highlighted row
                if table.cursor_row is not None:
                    try:
                        row_key = table.get_row_at(table.cursor_row)
                        if row_key:
                            row_index = int(row_key[0]) - 1
                            if 0 <= row_index < len(self.filtered_models):
                                selected = self.filtered_models[row_index]
                                self.dismiss(selected)
                    except (ValueError, IndexError, AttributeError):
                        # Fall back to previously selected model
                        if self.selected_model:
                            self.dismiss(self.selected_model)
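The per-row pricing logic above converts OpenRouter's per-token prices (reported as strings) into per-1M-token figures. Extracted as a standalone helper for illustration (hypothetical name `format_price`, not part of this diff), the same conversion can be sketched as:

```python
def format_price(pricing: dict) -> str:
    """Format OpenRouter per-token prices as prompt/completion cost per 1M tokens."""
    try:
        # OpenRouter reports prices as strings, per single token
        prompt = float(pricing.get("prompt", "0")) * 1_000_000
        completion = float(pricing.get("completion", "0")) * 1_000_000
    except (TypeError, ValueError):
        return "N/A"
    if prompt == 0 and completion == 0:
        return "Free"
    return f"${prompt:.2f}/${completion:.2f}"


print(format_price({"prompt": "0.000003", "completion": "0.000015"}))  # $3.00/$15.00
```

Catching `(TypeError, ValueError)` instead of a bare `except` keeps unexpected errors visible while still handling missing or malformed price strings.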
129
oai/tui/screens/stats_screen.py
Normal file
@@ -0,0 +1,129 @@
"""Statistics screen for oAI TUI."""

from textual.app import ComposeResult
from textual.containers import Container, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, Static

from oai.core.session import ChatSession


class StatsScreen(ModalScreen[None]):
    """Modal screen displaying session statistics."""

    DEFAULT_CSS = """
    StatsScreen {
        align: center middle;
    }

    StatsScreen > Container {
        width: 70;
        height: auto;
        background: #1e1e1e;
        border: solid #555555;
    }

    StatsScreen .header {
        dock: top;
        width: 100%;
        height: auto;
        background: #2d2d2d;
        color: #cccccc;
        padding: 0 2;
    }

    StatsScreen .content {
        width: 100%;
        height: auto;
        background: #1e1e1e;
        padding: 2;
        color: #cccccc;
    }

    StatsScreen .footer {
        dock: bottom;
        width: 100%;
        height: auto;
        background: #2d2d2d;
        padding: 1 2;
        align: center middle;
    }
    """

    def __init__(self, session: ChatSession):
        super().__init__()
        self.session = session

    def compose(self) -> ComposeResult:
        """Compose the screen."""
        with Container():
            yield Static("[bold]Session Statistics[/]", classes="header")
            with Vertical(classes="content"):
                yield Static(self._get_stats_text(), markup=True)
            with Vertical(classes="footer"):
                yield Button("Close", id="close", variant="primary")

    def _get_stats_text(self) -> str:
        """Generate the statistics text."""
        stats = self.session.stats

        # Calculate averages
        avg_input = stats.total_input_tokens // stats.message_count if stats.message_count > 0 else 0
        avg_output = stats.total_output_tokens // stats.message_count if stats.message_count > 0 else 0
        avg_cost = stats.total_cost / stats.message_count if stats.message_count > 0 else 0

        # Get model info
        model_name = "None"
        model_context = "N/A"
        if self.session.selected_model:
            model_name = self.session.selected_model.get("name", "Unknown")
            model_context = str(self.session.selected_model.get("context_length", "N/A"))

        # MCP status
        mcp_status = "Disabled"
        if self.session.mcp_manager and self.session.mcp_manager.enabled:
            mode = self.session.mcp_manager.mode
            if mode == "files":
                write = " (Write)" if self.session.mcp_manager.write_enabled else ""
                mcp_status = f"Enabled - Files{write}"
            elif mode == "database":
                db_idx = self.session.mcp_manager.selected_db_index
                if db_idx is not None:
                    db_name = self.session.mcp_manager.databases[db_idx]["name"]
                    mcp_status = f"Enabled - Database ({db_name})"

        return f"""
[bold cyan]═══ SESSION INFO ═══[/]
[bold]Messages:[/] {stats.message_count}
[bold]Current Model:[/] {model_name}
[bold]Context Length:[/] {model_context}
[bold]Memory:[/] {"Enabled" if self.session.memory_enabled else "Disabled"}
[bold]Online Mode:[/] {"Enabled" if self.session.online_enabled else "Disabled"}
[bold]MCP:[/] {mcp_status}

[bold cyan]═══ TOKEN USAGE ═══[/]
[bold]Input Tokens:[/] {stats.total_input_tokens:,}
[bold]Output Tokens:[/] {stats.total_output_tokens:,}
[bold]Total Tokens:[/] {stats.total_tokens:,}

[bold]Avg Input/Msg:[/] {avg_input:,}
[bold]Avg Output/Msg:[/] {avg_output:,}

[bold cyan]═══ COSTS ═══[/]
[bold]Total Cost:[/] ${stats.total_cost:.6f}
[bold]Avg Cost/Msg:[/] ${avg_cost:.6f}

[bold cyan]═══ HISTORY ═══[/]
[bold]History Size:[/] {len(self.session.history)} entries
[bold]Current Index:[/] {self.session.current_index + 1 if self.session.history else 0}
[bold]Memory Start:[/] {self.session.memory_start_index + 1}
"""

    def on_button_pressed(self, event: Button.Pressed) -> None:
        """Handle button press."""
        self.dismiss()

    def on_key(self, event) -> None:
        """Handle keyboard shortcuts."""
        if event.key in ("escape", "enter"):
            self.dismiss()
169
oai/tui/styles.tcss
Normal file
@@ -0,0 +1,169 @@
/* Textual CSS for oAI TUI - Using Textual Design System */

Screen {
    background: $background;
    overflow: hidden;
}

Header {
    dock: top;
    height: auto;
    background: #2d2d2d;
    color: #cccccc;
    padding: 0 1;
    border-bottom: solid #555555;
}

ChatDisplay {
    background: $background;
    border: none;
    padding: 1;
    scrollbar-background: $background;
    scrollbar-color: $primary;
    overflow-y: auto;
}

UserMessageWidget {
    margin: 0 0 1 0;
    padding: 1;
    background: $surface;
    border-left: thick $success;
    height: auto;
}

SystemMessageWidget {
    margin: 0 0 1 0;
    padding: 1;
    background: #2a2a2a;
    border-left: thick #888888;
    height: auto;
    color: #cccccc;
}

AssistantMessageWidget {
    margin: 0 0 1 0;
    padding: 1;
    background: $panel;
    border-left: thick $accent;
    height: auto;
}

#assistant-label {
    margin-bottom: 1;
    color: #cccccc;
}

#assistant-content {
    height: auto;
    max-height: 100%;
    color: #cccccc;
    link-color: #888888;
    link-style: none;
}

InputBar {
    dock: bottom;
    height: auto;
    background: #2d2d2d;
    align: center middle;
    border-top: solid #555555;
    padding: 1;
}

#input-prefix {
    width: auto;
    padding: 0 1;
    content-align: center middle;
    color: #888888;
}

#input-prefix.prefix-hidden {
    display: none;
}

#chat-input {
    width: 85%;
    height: 5;
    min-height: 5;
    background: #3a3a3a;
    border: none;
    padding: 1 2;
    color: #ffffff;
    content-align: left top;
}

#chat-input:focus {
    background: #404040;
}

#command-dropdown {
    display: none;
    dock: bottom;
    offset-y: -5;
    offset-x: 7.5%;
    height: auto;
    max-height: 12;
    width: 85%;
    background: #2d2d2d;
    border: solid #555555;
    padding: 0;
    layer: overlay;
}

#command-dropdown.visible {
    display: block;
}

#command-dropdown #command-list {
    background: #2d2d2d;
    border: none;
    scrollbar-background: #2d2d2d;
    scrollbar-color: #555555;
}

Footer {
    dock: bottom;
    height: auto;
    background: #252525;
    color: #888888;
    padding: 0 1;
}

/* Button styles */
Button {
    height: 3;
    min-width: 10;
    background: #3a3a3a;
    color: #cccccc;
    border: none;
}

Button:hover {
    background: #4a4a4a;
}

Button:focus {
    background: #505050;
}

Button.-primary {
    background: #3a3a3a;
}

Button.-success {
    background: #2d5016;
    color: #90ee90;
}

Button.-success:hover {
    background: #3a6b1e;
}

Button.-error {
    background: #5a1a1a;
    color: #ff6b6b;
}

Button.-error:hover {
    background: #6e2222;
}
17
oai/tui/widgets/__init__.py
Normal file
@@ -0,0 +1,17 @@
"""TUI widgets for oAI."""

from oai.tui.widgets.chat_display import ChatDisplay
from oai.tui.widgets.footer import Footer
from oai.tui.widgets.header import Header
from oai.tui.widgets.input_bar import InputBar
from oai.tui.widgets.message import AssistantMessageWidget, SystemMessageWidget, UserMessageWidget

__all__ = [
    "ChatDisplay",
    "Footer",
    "Header",
    "InputBar",
    "UserMessageWidget",
    "SystemMessageWidget",
    "AssistantMessageWidget",
]
21
oai/tui/widgets/chat_display.py
Normal file
@@ -0,0 +1,21 @@
"""Chat display widget for oAI TUI."""

from textual.containers import ScrollableContainer
from textual.widgets import Static


class ChatDisplay(ScrollableContainer):
    """Scrollable container for chat messages."""

    def __init__(self):
        super().__init__(id="chat-display")

    async def add_message(self, widget: Static) -> None:
        """Add a message widget to the display."""
        await self.mount(widget)
        self.scroll_end(animate=False)

    def clear_messages(self) -> None:
        """Clear all messages from the display."""
        for child in list(self.children):
            child.remove()
178
oai/tui/widgets/command_dropdown.py
Normal file
@@ -0,0 +1,178 @@
"""Command dropdown menu for TUI input."""

from textual.app import ComposeResult
from textual.containers import VerticalScroll
from textual.widget import Widget
from textual.widgets import Label, OptionList
from textual.widgets.option_list import Option

from oai.commands import registry


class CommandDropdown(VerticalScroll):
    """Dropdown menu showing available commands."""

    DEFAULT_CSS = """
    CommandDropdown {
        display: none;
        height: auto;
        max-height: 12;
        width: 80;
        background: #2d2d2d;
        border: solid #555555;
        padding: 0;
        layer: overlay;
    }

    CommandDropdown.visible {
        display: block;
    }

    CommandDropdown OptionList {
        height: auto;
        max-height: 12;
        background: #2d2d2d;
        border: none;
        padding: 0;
    }

    CommandDropdown OptionList > .option-list--option {
        padding: 0 2;
        color: #cccccc;
        background: transparent;
    }

    CommandDropdown OptionList > .option-list--option-highlighted {
        background: #3e3e3e;
        color: #ffffff;
    }
    """

    def __init__(self):
        """Initialize the command dropdown."""
        super().__init__(id="command-dropdown")
        self._all_commands = []
        self._load_commands()

    def _load_commands(self) -> None:
        """Load all available commands."""
        # Base commands with descriptions
        base_commands = [
            ("/help", "Show help screen"),
            ("/model", "Select AI model"),
            ("/stats", "Show session statistics"),
            ("/credits", "Check account credits"),
            ("/clear", "Clear chat display"),
            ("/reset", "Reset conversation history"),
            ("/memory on", "Enable conversation memory"),
            ("/memory off", "Disable memory"),
            ("/online on", "Enable online search"),
            ("/online off", "Disable online search"),
            ("/save", "Save current conversation"),
            ("/load", "Load saved conversation"),
            ("/list", "List saved conversations"),
            ("/delete", "Delete a conversation"),
            ("/export md", "Export as Markdown"),
            ("/export json", "Export as JSON"),
            ("/export html", "Export as HTML"),
            ("/prev", "Show previous message"),
            ("/next", "Show next message"),
            ("/config", "View configuration"),
            ("/config api", "Set API key"),
            ("/system", "Set system prompt"),
            ("/maxtoken", "Set token limit"),
            ("/retry", "Retry last prompt"),
            ("/paste", "Paste from clipboard"),
            ("/mcp on", "Enable MCP file access"),
            ("/mcp off", "Disable MCP"),
            ("/mcp status", "Show MCP status"),
            ("/mcp add", "Add folder/database"),
            ("/mcp remove", "Remove folder/database"),
            ("/mcp list", "List folders"),
            ("/mcp write on", "Enable write mode"),
            ("/mcp write off", "Disable write mode"),
            ("/mcp files", "Switch to file mode"),
            ("/mcp db list", "List databases"),
        ]

        self._all_commands = base_commands

    def compose(self) -> ComposeResult:
        """Compose the dropdown."""
        yield OptionList(id="command-list")

    def show_commands(self, filter_text: str = "") -> None:
        """Show commands matching the filter.

        Args:
            filter_text: Text to filter commands by
        """
        option_list = self.query_one("#command-list", OptionList)
        option_list.clear_options()

        if not filter_text.startswith("/"):
            self.remove_class("visible")
            return

        # Remove the leading slash for filtering
        filter_without_slash = filter_text[1:].lower()

        # Filter commands - show if filter text is contained anywhere in the command
        if filter_without_slash:
            matching = [
                (cmd, desc) for cmd, desc in self._all_commands
                if filter_without_slash in cmd[1:].lower()  # Skip the / in command for matching
            ]
        else:
            # Show all commands when just "/" is typed
            matching = self._all_commands

        if not matching:
            self.remove_class("visible")
            return

        # Add options - limit to 10 results
        for cmd, desc in matching[:10]:
            # Format: command in white, description in gray
            label = f"{cmd} [dim]{desc}[/]" if desc else cmd
            option_list.add_option(Option(label, id=cmd))

        self.add_class("visible")

        # Auto-select first option
        if option_list.option_count > 0:
            option_list.highlighted = 0

    def hide(self) -> None:
        """Hide the dropdown."""
        self.remove_class("visible")

    def get_selected_command(self) -> str | None:
        """Get the currently selected command.

        Returns:
            Selected command text or None
        """
        option_list = self.query_one("#command-list", OptionList)
        if option_list.highlighted is not None:
            option = option_list.get_option_at_index(option_list.highlighted)
            return option.id
        return None

    def move_selection_up(self) -> None:
        """Move selection up in the list."""
        option_list = self.query_one("#command-list", OptionList)
        if option_list.option_count > 0:
            if option_list.highlighted is None:
                option_list.highlighted = option_list.option_count - 1
            elif option_list.highlighted > 0:
                option_list.highlighted -= 1

    def move_selection_down(self) -> None:
        """Move selection down in the list."""
        option_list = self.query_one("#command-list", OptionList)
        if option_list.option_count > 0:
            if option_list.highlighted is None:
                option_list.highlighted = 0
            elif option_list.highlighted < option_list.option_count - 1:
                option_list.highlighted += 1
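The substring filter in `show_commands` is independent of any widget state; a minimal sketch of the same matching rule (hypothetical helper `filter_commands`, for illustration only) is:

```python
def filter_commands(filter_text: str, commands: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Return (command, description) pairs whose name contains the filter text."""
    if not filter_text.startswith("/"):
        return []
    needle = filter_text[1:].lower()
    if not needle:
        # A bare "/" shows everything
        return list(commands)
    # Match anywhere in the command name, skipping its leading slash
    return [(cmd, desc) for cmd, desc in commands if needle in cmd[1:].lower()]
```

Containment matching (rather than prefix matching) is what lets typing `/db` surface `/mcp db list`.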
58
oai/tui/widgets/command_suggester.py
Normal file
@@ -0,0 +1,58 @@
"""Command suggester for TUI input."""

from typing import Iterable

from textual.suggester import Suggester

from oai.commands import registry


class CommandSuggester(Suggester):
    """Suggester that provides command completions."""

    def __init__(self):
        """Initialize the command suggester."""
        super().__init__(use_cache=False, case_sensitive=False)
        # Get all command names from registry
        self._commands = []
        self._update_commands()

    def _update_commands(self) -> None:
        """Update the list of available commands."""
        # Get all registered command names
        command_names = registry.get_all_names()
        # Add common MCP subcommands for better UX
        mcp_subcommands = [
            "/mcp on",
            "/mcp off",
            "/mcp status",
            "/mcp add",
            "/mcp remove",
            "/mcp list",
            "/mcp write on",
            "/mcp write off",
            "/mcp files",
            "/mcp db list",
        ]
        self._commands = command_names + mcp_subcommands

    async def get_suggestion(self, value: str) -> str | None:
        """Get a command suggestion based on the current input.

        Args:
            value: Current input value

        Returns:
            Suggested completion or None
        """
        if not value or not value.startswith("/"):
            return None

        # Find the first command that starts with the input
        value_lower = value.lower()
        for cmd in self._commands:
            if cmd.lower().startswith(value_lower) and cmd.lower() != value_lower:
                # Return the rest of the command (after what's already typed)
                return cmd[len(value):]

        return None
39
oai/tui/widgets/footer.py
Normal file
@@ -0,0 +1,39 @@
"""Footer widget for oAI TUI."""

from textual.app import ComposeResult
from textual.widgets import Static


class Footer(Static):
    """Footer displaying session metrics."""

    def __init__(self):
        super().__init__()
        self.tokens_in = 0
        self.tokens_out = 0
        self.cost = 0.0
        self.messages = 0

    def compose(self) -> ComposeResult:
        """Compose the footer."""
        yield Static(self._format_footer(), id="footer-content")

    def _format_footer(self) -> str:
        """Format the footer text."""
        return (
            f"[dim]Messages: {self.messages} | "
            f"Tokens: {self.tokens_in + self.tokens_out:,} "
            f"({self.tokens_in:,} in, {self.tokens_out:,} out) | "
            f"Cost: ${self.cost:.4f}[/]"
        )

    def update_stats(
        self, tokens_in: int, tokens_out: int, cost: float, messages: int
    ) -> None:
        """Update the displayed statistics."""
        self.tokens_in = tokens_in
        self.tokens_out = tokens_out
        self.cost = cost
        self.messages = messages
        content = self.query_one("#footer-content", Static)
        content.update(self._format_footer())
65
oai/tui/widgets/header.py
Normal file
@@ -0,0 +1,65 @@
"""Header widget for oAI TUI."""

from textual.app import ComposeResult
from textual.widgets import Static
from typing import Optional, Dict, Any


class Header(Static):
    """Header displaying app title, version, current model, and capabilities."""

    def __init__(self, version: str = "3.0.1", model: str = "", model_info: Optional[Dict[str, Any]] = None):
        super().__init__()
        self.version = version
        self.model = model
        self.model_info = model_info or {}

    def compose(self) -> ComposeResult:
        """Compose the header."""
        yield Static(self._format_header(), id="header-content")

    def _format_capabilities(self) -> str:
        """Format capability icons based on model info."""
        if not self.model_info:
            return ""

        icons = []

        # Check vision support
        architecture = self.model_info.get("architecture", {})
        modality = architecture.get("modality", "")
        if "image" in modality:
            icons.append("[bold cyan]👁️[/]")  # Bright if supported
        else:
            icons.append("[dim]👁️[/]")  # Dim if not supported

        # Check tool support
        supported_params = self.model_info.get("supported_parameters", [])
        if "tools" in supported_params or "tool_choice" in supported_params:
            icons.append("[bold cyan]🔧[/]")
        else:
            icons.append("[dim]🔧[/]")

        # Check online support (most models support the :online suffix)
        model_id = self.model_info.get("id", "")
        if ":online" in model_id or model_id not in ["openrouter/free"]:
            icons.append("[bold cyan]🌐[/]")
        else:
            icons.append("[dim]🌐[/]")

        return " ".join(icons) if icons else ""

    def _format_header(self) -> str:
        """Format the header text."""
        model_text = f" | {self.model}" if self.model else ""
        capabilities = self._format_capabilities()
        capabilities_text = f" {capabilities}" if capabilities else ""
        return f"[bold cyan]oAI[/] [dim]v{self.version}[/]{model_text}{capabilities_text}"

    def update_model(self, model: str, model_info: Optional[Dict[str, Any]] = None) -> None:
        """Update the displayed model and capabilities."""
        self.model = model
        if model_info:
            self.model_info = model_info
        content = self.query_one("#header-content", Static)
        content.update(self._format_header())
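The vision and tool checks here mirror the ones in the model selector; the shared rule can be sketched as a pure function (hypothetical name `detect_capabilities`, for illustration) over an OpenRouter model record:

```python
def detect_capabilities(model_info: dict) -> dict:
    """Derive vision and tool-calling support flags from an OpenRouter model record."""
    modality = model_info.get("architecture", {}).get("modality", "")
    params = model_info.get("supported_parameters", [])
    return {
        "vision": "image" in modality,                          # e.g. "text+image->text"
        "tools": "tools" in params or "tool_choice" in params,  # function calling
    }
```

Factoring the checks out this way would let the header and the model selector stay in sync if the detection rules change.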
49
oai/tui/widgets/input_bar.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
"""Input bar widget for oAI TUI."""
|
||||||
|
|
||||||
|
from textual.app import ComposeResult
|
||||||
|
from textual.containers import Horizontal
|
||||||
|
from textual.widgets import Input, Static
|
||||||
|
|
||||||
|
|
||||||
|
class InputBar(Horizontal):
|
||||||
|
"""Input bar with prompt prefix and text input."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(id="input-bar")
|
||||||
|
self.mcp_status = ""
|
||||||
|
self.online_mode = False
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
"""Compose the input bar."""
|
||||||
|
yield Static(self._format_prefix(), id="input-prefix", classes="prefix-hidden" if not (self.mcp_status or self.online_mode) else "")
|
||||||
|
yield Input(
|
||||||
|
placeholder="Type a message or /command...",
|
||||||
|
id="chat-input"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _format_prefix(self) -> str:
|
||||||
|
"""Format the input prefix with status indicators."""
|
||||||
|
indicators = []
|
||||||
|
if self.mcp_status:
|
||||||
|
indicators.append(f"[cyan]{self.mcp_status}[/]")
|
||||||
|
if self.online_mode:
|
||||||
|
indicators.append("[green]🌐[/]")
|
||||||
|
|
||||||
|
prefix = " ".join(indicators) + " " if indicators else ""
|
||||||
|
return f"{prefix}[bold]>[/]"
|
||||||
|
|
||||||
|
def update_mcp_status(self, status: str) -> None:
|
||||||
|
"""Update MCP status indicator."""
|
||||||
|
self.mcp_status = status
|
||||||
|
prefix = self.query_one("#input-prefix", Static)
|
||||||
|
prefix.update(self._format_prefix())
|
||||||
|
|
||||||
|
def update_online_mode(self, online: bool) -> None:
|
||||||
|
"""Update online mode indicator."""
|
||||||
|
self.online_mode = online
|
||||||
|
prefix = self.query_one("#input-prefix", Static)
|
||||||
|
prefix.update(self._format_prefix())
|
||||||
|
|
||||||
|
def get_input(self) -> Input:
|
||||||
|
"""Get the input widget."""
|
||||||
|
return self.query_one("#chat-input", Input)
|
||||||
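The prefix logic above is pure string formatting, so it can be exercised outside of Textual. A minimal sketch (the free function `format_prefix` is hypothetical; the widget method reads the same state from `self`):

```python
def format_prefix(mcp_status: str, online_mode: bool) -> str:
    # Mirrors InputBar._format_prefix: optional MCP and online
    # indicators, followed by a bold ">" prompt.
    indicators = []
    if mcp_status:
        indicators.append(f"[cyan]{mcp_status}[/]")
    if online_mode:
        indicators.append("[green]🌐[/]")
    prefix = " ".join(indicators) + " " if indicators else ""
    return f"{prefix}[bold]>[/]"

print(format_prefix("", False))        # plain prompt: [bold]>[/]
print(format_prefix("2 tools", True))  # both indicators active
```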
92 oai/tui/widgets/message.py Normal file
@@ -0,0 +1,92 @@
"""Message widgets for oAI TUI."""

from typing import Any, AsyncIterator, Tuple

from rich.console import Console
from rich.markdown import Markdown
from rich.style import Style
from rich.theme import Theme
from textual.app import ComposeResult
from textual.widgets import RichLog, Static

# Custom theme for Markdown rendering - neutral colors matching the dark theme
MARKDOWN_THEME = Theme({
    "markdown.text": Style(color="#cccccc"),
    "markdown.paragraph": Style(color="#cccccc"),
    "markdown.code": Style(color="#e0e0e0", bgcolor="#2a2a2a"),
    "markdown.code_block": Style(color="#e0e0e0", bgcolor="#2a2a2a"),
    "markdown.heading": Style(color="#ffffff", bold=True),
    "markdown.h1": Style(color="#ffffff", bold=True),
    "markdown.h2": Style(color="#eeeeee", bold=True),
    "markdown.h3": Style(color="#dddddd", bold=True),
    "markdown.link": Style(color="#aaaaaa", underline=False),
    "markdown.link_url": Style(color="#888888"),
    "markdown.emphasis": Style(color="#cccccc", italic=True),
    "markdown.strong": Style(color="#ffffff", bold=True),
})


class UserMessageWidget(Static):
    """Widget for displaying user messages."""

    def __init__(self, content: str):
        super().__init__()
        self.content = content

    def compose(self) -> ComposeResult:
        """Compose the user message."""
        yield Static(f"[bold green]You:[/] {self.content}")


class SystemMessageWidget(Static):
    """Widget for displaying system/info messages without 'You:' prefix."""

    def __init__(self, content: str):
        super().__init__()
        self.content = content

    def compose(self) -> ComposeResult:
        """Compose the system message."""
        yield Static(self.content)


class AssistantMessageWidget(Static):
    """Widget for displaying assistant responses with streaming support."""

    def __init__(self, model_name: str = "Assistant"):
        super().__init__()
        self.model_name = model_name
        self.full_text = ""

    def compose(self) -> ComposeResult:
        """Compose the assistant message."""
        yield Static(f"[bold]{self.model_name}:[/]", id="assistant-label")
        yield RichLog(id="assistant-content", highlight=True, markup=True, wrap=True)

    async def stream_response(self, response_iterator: AsyncIterator) -> Tuple[str, Any]:
        """Stream tokens progressively and return final text and usage."""
        log = self.query_one("#assistant-content", RichLog)
        self.full_text = ""
        usage = None

        async for chunk in response_iterator:
            if hasattr(chunk, "delta_content") and chunk.delta_content:
                self.full_text += chunk.delta_content
                log.clear()
                # Use neutral code theme for syntax highlighting
                md = Markdown(self.full_text, code_theme="github-dark", inline_code_theme="github-dark")
                log.write(md)

            if hasattr(chunk, "usage") and chunk.usage:
                usage = chunk.usage

        return self.full_text, usage

    def set_content(self, content: str) -> None:
        """Set the complete content (non-streaming)."""
        self.full_text = content
        log = self.query_one("#assistant-content", RichLog)
        log.clear()
        # Use neutral code theme for syntax highlighting
        md = Markdown(content, code_theme="github-dark", inline_code_theme="github-dark")
        log.write(md)
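The loop in `stream_response` simply accumulates `delta_content` and captures the final `usage` object; that core can be sketched with a fake chunk type (the `Chunk` dataclass and `fake_stream` below are illustrative stand-ins, not part of the API):

```python
import asyncio
from dataclasses import dataclass
from typing import Optional

@dataclass
class Chunk:
    # Hypothetical stand-in for a streaming API chunk.
    delta_content: str = ""
    usage: Optional[dict] = None

async def fake_stream():
    for piece in ("Hel", "lo, ", "world"):
        yield Chunk(delta_content=piece)
    yield Chunk(usage={"total_tokens": 5})

async def accumulate(iterator):
    # Same shape as AssistantMessageWidget.stream_response, minus rendering.
    full_text, usage = "", None
    async for chunk in iterator:
        if chunk.delta_content:
            full_text += chunk.delta_content
        if chunk.usage:
            usage = chunk.usage
    return full_text, usage

full, usage = asyncio.run(accumulate(fake_stream()))
print(full)   # Hello, world
print(usage)  # {'total_tokens': 5}
```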
20 oai/utils/__init__.py Normal file
@@ -0,0 +1,20 @@
"""
Utility modules for oAI.

This package provides common utilities used throughout the application
including logging, file handling, and export functionality.
"""

from oai.utils.logging import setup_logging, get_logger
from oai.utils.files import read_file_safe, is_binary_file
from oai.utils.export import export_as_markdown, export_as_json, export_as_html

__all__ = [
    "setup_logging",
    "get_logger",
    "read_file_safe",
    "is_binary_file",
    "export_as_markdown",
    "export_as_json",
    "export_as_html",
]
248 oai/utils/export.py Normal file
@@ -0,0 +1,248 @@
"""
Export utilities for oAI.

This module provides functions for exporting conversation history
in various formats including Markdown, JSON, and HTML.
"""

import json
import datetime
from typing import List, Dict
from html import escape as html_escape

from oai.constants import APP_VERSION, APP_URL


def export_as_markdown(
    session_history: List[Dict[str, str]],
    session_system_prompt: str = ""
) -> str:
    """
    Export conversation history as Markdown.

    Args:
        session_history: List of message dictionaries with 'prompt' and 'response'
        session_system_prompt: Optional system prompt to include

    Returns:
        Markdown formatted string
    """
    lines = ["# Conversation Export", ""]

    if session_system_prompt:
        lines.extend([f"**System Prompt:** {session_system_prompt}", ""])

    lines.append(f"**Export Date:** {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    lines.append(f"**Messages:** {len(session_history)}")
    lines.append("")
    lines.append("---")
    lines.append("")

    for i, entry in enumerate(session_history, 1):
        lines.append(f"## Message {i}")
        lines.append("")
        lines.append("**User:**")
        lines.append("")
        lines.append(entry.get("prompt", ""))
        lines.append("")
        lines.append("**Assistant:**")
        lines.append("")
        lines.append(entry.get("response", ""))
        lines.append("")
        lines.append("---")
        lines.append("")

    lines.append(f"*Exported from oAI v{APP_VERSION} - {APP_URL}*")

    return "\n".join(lines)


def export_as_json(
    session_history: List[Dict[str, str]],
    session_system_prompt: str = ""
) -> str:
    """
    Export conversation history as JSON.

    Args:
        session_history: List of message dictionaries
        session_system_prompt: Optional system prompt to include

    Returns:
        JSON formatted string
    """
    export_data = {
        "export_date": datetime.datetime.now().isoformat(),
        "app_version": APP_VERSION,
        "system_prompt": session_system_prompt,
        "message_count": len(session_history),
        "messages": [
            {
                "index": i + 1,
                "prompt": entry.get("prompt", ""),
                "response": entry.get("response", ""),
                "prompt_tokens": entry.get("prompt_tokens", 0),
                "completion_tokens": entry.get("completion_tokens", 0),
                "cost": entry.get("msg_cost", 0.0),
            }
            for i, entry in enumerate(session_history)
        ],
        "totals": {
            "prompt_tokens": sum(e.get("prompt_tokens", 0) for e in session_history),
            "completion_tokens": sum(e.get("completion_tokens", 0) for e in session_history),
            "total_cost": sum(e.get("msg_cost", 0.0) for e in session_history),
        }
    }
    return json.dumps(export_data, indent=2, ensure_ascii=False)


def export_as_html(
    session_history: List[Dict[str, str]],
    session_system_prompt: str = ""
) -> str:
    """
    Export conversation history as styled HTML.

    Args:
        session_history: List of message dictionaries
        session_system_prompt: Optional system prompt to include

    Returns:
        HTML formatted string with embedded CSS
    """
    html_parts = [
        "<!DOCTYPE html>",
        "<html>",
        "<head>",
        "    <meta charset='UTF-8'>",
        "    <meta name='viewport' content='width=device-width, initial-scale=1.0'>",
        "    <title>Conversation Export - oAI</title>",
        "    <style>",
        "        * { box-sizing: border-box; }",
        "        body {",
        "            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;",
        "            max-width: 900px;",
        "            margin: 40px auto;",
        "            padding: 20px;",
        "            background: #f5f5f5;",
        "            color: #333;",
        "        }",
        "        .header {",
        "            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);",
        "            color: white;",
        "            padding: 30px;",
        "            border-radius: 10px;",
        "            margin-bottom: 30px;",
        "            box-shadow: 0 4px 6px rgba(0,0,0,0.1);",
        "        }",
        "        .header h1 {",
        "            margin: 0 0 10px 0;",
        "            font-size: 2em;",
        "        }",
        "        .export-info {",
        "            opacity: 0.9;",
        "            font-size: 0.95em;",
        "            margin: 5px 0;",
        "        }",
        "        .system-prompt {",
        "            background: #fff3cd;",
        "            padding: 20px;",
        "            border-radius: 8px;",
        "            margin-bottom: 25px;",
        "            border-left: 5px solid #ffc107;",
        "            box-shadow: 0 2px 4px rgba(0,0,0,0.05);",
        "        }",
        "        .system-prompt strong {",
        "            color: #856404;",
        "            display: block;",
        "            margin-bottom: 10px;",
        "            font-size: 1.1em;",
        "        }",
        "        .message-container { margin-bottom: 20px; }",
        "        .message {",
        "            background: white;",
        "            padding: 20px;",
        "            border-radius: 8px;",
        "            box-shadow: 0 2px 4px rgba(0,0,0,0.08);",
        "            margin-bottom: 12px;",
        "        }",
        "        .user-message { border-left: 5px solid #10b981; }",
        "        .assistant-message { border-left: 5px solid #3b82f6; }",
        "        .role {",
        "            font-weight: bold;",
        "            margin-bottom: 12px;",
        "            font-size: 1.05em;",
        "            text-transform: uppercase;",
        "            letter-spacing: 0.5px;",
        "        }",
        "        .user-role { color: #10b981; }",
        "        .assistant-role { color: #3b82f6; }",
        "        .content {",
        "            line-height: 1.8;",
        "            white-space: pre-wrap;",
        "            color: #333;",
        "        }",
        "        .message-number {",
        "            color: #6b7280;",
        "            font-size: 0.85em;",
        "            margin-bottom: 15px;",
        "            font-weight: 600;",
        "        }",
        "        .footer {",
        "            text-align: center;",
        "            margin-top: 40px;",
        "            padding: 20px;",
        "            color: #6b7280;",
        "            font-size: 0.9em;",
        "        }",
        "        .footer a { color: #667eea; text-decoration: none; }",
        "        .footer a:hover { text-decoration: underline; }",
        "        @media print {",
        "            body { background: white; }",
        "            .message { break-inside: avoid; }",
        "        }",
        "    </style>",
        "</head>",
        "<body>",
        "    <div class='header'>",
        "        <h1>Conversation Export</h1>",
        f"        <div class='export-info'>Exported: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</div>",
        f"        <div class='export-info'>Total Messages: {len(session_history)}</div>",
        "    </div>",
    ]

    if session_system_prompt:
        html_parts.extend([
            "    <div class='system-prompt'>",
            "        <strong>System Prompt</strong>",
            f"        <div>{html_escape(session_system_prompt)}</div>",
            "    </div>",
        ])

    for i, entry in enumerate(session_history, 1):
        prompt = html_escape(entry.get("prompt", ""))
        response = html_escape(entry.get("response", ""))

        html_parts.extend([
            "    <div class='message-container'>",
            f"        <div class='message-number'>Message {i} of {len(session_history)}</div>",
            "        <div class='message user-message'>",
            "            <div class='role user-role'>User</div>",
            f"            <div class='content'>{prompt}</div>",
            "        </div>",
            "        <div class='message assistant-message'>",
            "            <div class='role assistant-role'>Assistant</div>",
            f"            <div class='content'>{response}</div>",
            "        </div>",
            "    </div>",
        ])

    html_parts.extend([
        "    <div class='footer'>",
        f"        <p>Generated by oAI v{APP_VERSION} • <a href='{APP_URL}'>{APP_URL}</a></p>",
        "    </div>",
        "</body>",
        "</html>",
    ])

    return "\n".join(html_parts)
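The `totals` block that `export_as_json` emits can be checked in isolation; this sketch recomputes the same aggregation over a made-up two-message history (the token and cost values are illustrative):

```python
session_history = [
    {"prompt": "hi", "response": "hello", "prompt_tokens": 3,
     "completion_tokens": 5, "msg_cost": 0.0001},
    {"prompt": "thanks", "response": "any time", "prompt_tokens": 4,
     "completion_tokens": 6, "msg_cost": 0.0002},
]

# Same aggregation as the "totals" section of export_as_json.
totals = {
    "prompt_tokens": sum(e.get("prompt_tokens", 0) for e in session_history),
    "completion_tokens": sum(e.get("completion_tokens", 0) for e in session_history),
    "total_cost": sum(e.get("msg_cost", 0.0) for e in session_history),
}
print(totals["prompt_tokens"], totals["completion_tokens"])  # 7 11
```

Missing keys fall back to `0` / `0.0` via `dict.get`, so partially populated history entries do not break the export.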
323 oai/utils/files.py Normal file
@@ -0,0 +1,323 @@
"""
File handling utilities for oAI.

This module provides safe file reading, type detection, and other
file-related operations used throughout the application.
"""

import os
import mimetypes
import base64
from pathlib import Path
from typing import Optional, Dict, Any, Tuple

from oai.constants import (
    MAX_FILE_SIZE,
    CONTENT_TRUNCATION_THRESHOLD,
    SUPPORTED_CODE_EXTENSIONS,
    ALLOWED_FILE_EXTENSIONS,
)
from oai.utils.logging import get_logger


def is_binary_file(file_path: Path) -> bool:
    """
    Check if a file appears to be binary.

    Args:
        file_path: Path to the file to check

    Returns:
        True if the file appears to be binary, False otherwise
    """
    try:
        with open(file_path, "rb") as f:
            # Read first 8KB to check for binary content
            chunk = f.read(8192)
            # Check for null bytes (common in binary files)
            if b"\x00" in chunk:
                return True
            # Try to decode as UTF-8
            try:
                chunk.decode("utf-8")
                return False
            except UnicodeDecodeError:
                return True
    except Exception:
        return True


def get_file_type(file_path: Path) -> Tuple[Optional[str], str]:
    """
    Determine the MIME type and category of a file.

    Args:
        file_path: Path to the file

    Returns:
        Tuple of (mime_type, category) where category is one of:
        'image', 'pdf', 'code', 'text', 'binary', 'unknown'
    """
    mime_type, _ = mimetypes.guess_type(str(file_path))
    ext = file_path.suffix.lower()

    if mime_type and mime_type.startswith("image/"):
        return mime_type, "image"
    elif mime_type == "application/pdf" or ext == ".pdf":
        return mime_type or "application/pdf", "pdf"
    elif ext in SUPPORTED_CODE_EXTENSIONS:
        return mime_type or "text/plain", "code"
    elif mime_type and mime_type.startswith("text/"):
        return mime_type, "text"
    elif is_binary_file(file_path):
        return mime_type, "binary"
    else:
        return mime_type, "unknown"


def read_file_safe(
    file_path: Path,
    max_size: int = MAX_FILE_SIZE,
    truncate_threshold: int = CONTENT_TRUNCATION_THRESHOLD
) -> Dict[str, Any]:
    """
    Safely read a file with size limits and truncation support.

    Args:
        file_path: Path to the file to read
        max_size: Maximum file size to read (bytes)
        truncate_threshold: Threshold for truncating large files

    Returns:
        Dictionary containing:
        - content: File content (text or base64)
        - size: File size in bytes
        - truncated: Whether content was truncated
        - encoding: 'text', 'base64', or None on error
        - error: Error message if reading failed
    """
    logger = get_logger()

    try:
        path = Path(file_path).resolve()

        if not path.exists():
            return {
                "content": None,
                "size": 0,
                "truncated": False,
                "encoding": None,
                "error": f"File not found: {path}"
            }

        if not path.is_file():
            return {
                "content": None,
                "size": 0,
                "truncated": False,
                "encoding": None,
                "error": f"Not a file: {path}"
            }

        file_size = path.stat().st_size

        if file_size > max_size:
            return {
                "content": None,
                "size": file_size,
                "truncated": False,
                "encoding": None,
                "error": f"File too large: {file_size / (1024*1024):.1f}MB (max: {max_size / (1024*1024):.0f}MB)"
            }

        # Try to read as text first
        try:
            content = path.read_text(encoding="utf-8")

            # Check if truncation is needed
            if file_size > truncate_threshold:
                lines = content.split("\n")
                total_lines = len(lines)

                # Keep first 500 lines and last 100 lines
                head_lines = 500
                tail_lines = 100

                if total_lines > (head_lines + tail_lines):
                    truncated_content = (
                        "\n".join(lines[:head_lines]) +
                        f"\n\n... [TRUNCATED: {total_lines - head_lines - tail_lines} lines omitted] ...\n\n" +
                        "\n".join(lines[-tail_lines:])
                    )
                    logger.info(f"Read file (truncated): {path} ({file_size} bytes, {total_lines} lines)")
                    return {
                        "content": truncated_content,
                        "size": file_size,
                        "truncated": True,
                        "total_lines": total_lines,
                        "lines_shown": head_lines + tail_lines,
                        "encoding": "text",
                        "error": None
                    }

            logger.info(f"Read file: {path} ({file_size} bytes)")
            return {
                "content": content,
                "size": file_size,
                "truncated": False,
                "encoding": "text",
                "error": None
            }

        except UnicodeDecodeError:
            # File is binary, return base64 encoded
            with open(path, "rb") as f:
                binary_data = f.read()
            b64_content = base64.b64encode(binary_data).decode("utf-8")
            logger.info(f"Read binary file: {path} ({file_size} bytes)")
            return {
                "content": b64_content,
                "size": file_size,
                "truncated": False,
                "encoding": "base64",
                "error": None
            }

    except PermissionError as e:
        return {
            "content": None,
            "size": 0,
            "truncated": False,
            "encoding": None,
            "error": f"Permission denied: {e}"
        }
    except Exception as e:
        logger.error(f"Error reading file {file_path}: {e}")
        return {
            "content": None,
            "size": 0,
            "truncated": False,
            "encoding": None,
            "error": str(e)
        }


def get_file_extension(file_path: Path) -> str:
    """
    Get the lowercase file extension.

    Args:
        file_path: Path to the file

    Returns:
        Lowercase extension including the dot (e.g., '.py')
    """
    return file_path.suffix.lower()


def is_allowed_extension(file_path: Path) -> bool:
    """
    Check if a file has an allowed extension for attachment.

    Args:
        file_path: Path to the file

    Returns:
        True if the extension is allowed, False otherwise
    """
    return get_file_extension(file_path) in ALLOWED_FILE_EXTENSIONS


def format_file_size(size_bytes: int) -> str:
    """
    Format a file size in human-readable format.

    Args:
        size_bytes: Size in bytes

    Returns:
        Formatted string (e.g., '1.5 MB', '512 KB')
    """
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if abs(size_bytes) < 1024:
            return f"{size_bytes:.1f} {unit}"
        size_bytes /= 1024
    return f"{size_bytes:.1f} PB"


def prepare_file_attachment(
    file_path: Path,
    model_capabilities: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
    """
    Prepare a file for attachment to an API request.

    Args:
        file_path: Path to the file
        model_capabilities: Model capability information

    Returns:
        Content block dictionary for the API, or None if unsupported
    """
    logger = get_logger()
    path = Path(file_path).resolve()

    if not path.exists():
        logger.warning(f"File not found: {path}")
        return None

    mime_type, category = get_file_type(path)
    file_size = path.stat().st_size

    if file_size > MAX_FILE_SIZE:
        logger.warning(f"File too large: {path} ({format_file_size(file_size)})")
        return None

    try:
        with open(path, "rb") as f:
            file_data = f.read()

        if category == "image":
            # Check if model supports images
            input_modalities = model_capabilities.get("architecture", {}).get("input_modalities", [])
            if "image" not in input_modalities:
                logger.warning("Model does not support images")
                return None

            b64_data = base64.b64encode(file_data).decode("utf-8")
            return {
                "type": "image_url",
                "image_url": {"url": f"data:{mime_type};base64,{b64_data}"}
            }

        elif category == "pdf":
            # Check if model supports PDFs
            input_modalities = model_capabilities.get("architecture", {}).get("input_modalities", [])
            supports_pdf = any(mod in input_modalities for mod in ["document", "pdf", "file"])
            if not supports_pdf:
                logger.warning("Model does not support PDFs")
                return None

            b64_data = base64.b64encode(file_data).decode("utf-8")
            return {
                "type": "image_url",
                "image_url": {"url": f"data:application/pdf;base64,{b64_data}"}
            }

        elif category in ("code", "text"):
            text_content = file_data.decode("utf-8")
            return {
                "type": "text",
                "text": f"File: {path.name}\n\n{text_content}"
            }

        else:
            logger.warning(f"Unsupported file type: {category} ({mime_type})")
            return None

    except UnicodeDecodeError:
        logger.error(f"Cannot decode file as UTF-8: {path}")
        return None
    except Exception as e:
        logger.error(f"Error preparing file attachment {path}: {e}")
        return None
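`format_file_size` walks up the unit list, dividing by 1024 until the value fits; copied here verbatim so its rounding behavior can be checked standalone:

```python
def format_file_size(size_bytes: int) -> str:
    # Walk up the units, dividing by 1024 each step, until the
    # value drops below one unit; format with one decimal place.
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if abs(size_bytes) < 1024:
            return f"{size_bytes:.1f} {unit}"
        size_bytes /= 1024
    return f"{size_bytes:.1f} PB"

print(format_file_size(512))              # 512.0 B
print(format_file_size(1536))             # 1.5 KB
print(format_file_size(5 * 1024 * 1024))  # 5.0 MB
```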
297 oai/utils/logging.py Normal file
@@ -0,0 +1,297 @@
"""
Logging configuration for oAI.

This module provides centralized logging setup with Rich formatting,
file rotation, and configurable log levels.
"""

import io
import os
import glob
import logging
import datetime
import shutil
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Optional

from rich.console import Console
from rich.logging import RichHandler

from oai.constants import (
    LOG_FILE,
    CONFIG_DIR,
    DEFAULT_LOG_MAX_SIZE_MB,
    DEFAULT_LOG_BACKUP_COUNT,
    DEFAULT_LOG_LEVEL,
    VALID_LOG_LEVELS,
)


class RotatingRichHandler(RotatingFileHandler):
    """
    Custom log handler combining file rotation with Rich formatting.

    This handler writes Rich-formatted log output to a rotating file,
    providing colored and formatted logs even in file output while
    managing file size and backups automatically.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the handler with Rich console for formatting."""
        super().__init__(*args, **kwargs)
        # Create an internal console for Rich formatting
        self.rich_console = Console(
            file=io.StringIO(),
            width=120,
            force_terminal=False
        )
        self.rich_handler = RichHandler(
            console=self.rich_console,
            show_time=True,
            show_path=True,
            rich_tracebacks=True,
            tracebacks_suppress=["requests", "openrouter", "urllib3", "httpx", "openai"]
        )

    def emit(self, record: logging.LogRecord) -> None:
        """
        Emit a log record with Rich formatting.

        Args:
            record: The log record to emit
        """
        try:
            # Format with Rich
            self.rich_handler.emit(record)
            output = self.rich_console.file.getvalue()
            self.rich_console.file.seek(0)
            self.rich_console.file.truncate(0)

            if output:
                self.stream.write(output)
                self.flush()
        except Exception:
            self.handleError(record)


class LoggingManager:
    """
    Manages application logging configuration.

    Provides methods to setup, configure, and manage logging with
    support for runtime reconfiguration and level changes.
    """

    def __init__(self):
        """Initialize the logging manager."""
        self.handler: Optional[RotatingRichHandler] = None
        self.app_logger: Optional[logging.Logger] = None
        self.max_size_mb: int = DEFAULT_LOG_MAX_SIZE_MB
        self.backup_count: int = DEFAULT_LOG_BACKUP_COUNT
        self.log_level: str = DEFAULT_LOG_LEVEL

    def setup(
        self,
        max_size_mb: Optional[int] = None,
        backup_count: Optional[int] = None,
        log_level: Optional[str] = None
    ) -> logging.Logger:
        """
        Setup or reconfigure logging.

        Args:
            max_size_mb: Maximum log file size in MB
            backup_count: Number of backup files to keep
            log_level: Logging level string

        Returns:
            The configured application logger
        """
        # Update configuration if provided
        if max_size_mb is not None:
            self.max_size_mb = max_size_mb
        if backup_count is not None:
            self.backup_count = backup_count
        if log_level is not None:
            self.log_level = log_level

        # Ensure config directory exists
        CONFIG_DIR.mkdir(parents=True, exist_ok=True)

        # Get root logger
        root_logger = logging.getLogger()

        # Remove existing handler if present
        if self.handler is not None:
            root_logger.removeHandler(self.handler)
            try:
                self.handler.close()
            except Exception:
                pass

        # Check if log needs manual rotation
        self._check_rotation()

        # Create new handler
        max_bytes = self.max_size_mb * 1024 * 1024
        self.handler = RotatingRichHandler(
            filename=str(LOG_FILE),
            maxBytes=max_bytes,
            backupCount=self.backup_count,
            encoding="utf-8"
        )

        self.handler.setLevel(logging.NOTSET)
        root_logger.setLevel(logging.WARNING)
        root_logger.addHandler(self.handler)

        # Suppress noisy third-party loggers
        for logger_name in [
            "asyncio", "urllib3", "requests", "httpx",
            "httpcore", "openai", "openrouter"
        ]:
            logging.getLogger(logger_name).setLevel(logging.WARNING)

        # Configure application logger
        self.app_logger = logging.getLogger("oai_app")
        level = VALID_LOG_LEVELS.get(self.log_level.lower(), logging.INFO)
        self.app_logger.setLevel(level)
        self.app_logger.propagate = True

        return self.app_logger

    def _check_rotation(self) -> None:
        """Check if log file needs rotation and perform if necessary."""
        if not LOG_FILE.exists():
            return

        current_size = LOG_FILE.stat().st_size
        max_bytes = self.max_size_mb * 1024 * 1024

        if current_size >= max_bytes:
            # Perform manual rotation
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            backup_file = f"{LOG_FILE}.{timestamp}"

            try:
                shutil.move(str(LOG_FILE), backup_file)
            except Exception:
                pass

            # Clean old backups
            self._cleanup_old_backups()

    def _cleanup_old_backups(self) -> None:
        """Remove old backup files exceeding the backup count."""
        log_dir = LOG_FILE.parent
        backup_pattern = f"{LOG_FILE.name}.*"
        backups = sorted(glob.glob(str(log_dir / backup_pattern)))

        while len(backups) > self.backup_count:
            oldest = backups.pop(0)
            try:
                os.remove(oldest)
            except Exception:
                pass

    def set_level(self, level: str) -> bool:
        """
        Set the application log level.

        Args:
            level: Log level string (debug/info/warning/error/critical)

        Returns:
            True if level was set successfully, False otherwise
        """
        level_lower = level.lower()
        if level_lower not in VALID_LOG_LEVELS:
            return False

        self.log_level = level_lower
        if self.app_logger:
            self.app_logger.setLevel(VALID_LOG_LEVELS[level_lower])
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_logger(self) -> logging.Logger:
|
||||||
|
"""
|
||||||
|
Get the application logger, initializing if necessary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The application logger
|
||||||
|
"""
|
||||||
|
if self.app_logger is None:
|
||||||
|
self.setup()
|
||||||
|
return self.app_logger
|
||||||
|
|
||||||
|
|
||||||
|
# Global logging manager instance
|
||||||
|
_logging_manager = LoggingManager()
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(
|
||||||
|
max_size_mb: Optional[int] = None,
|
||||||
|
backup_count: Optional[int] = None,
|
||||||
|
log_level: Optional[str] = None
|
||||||
|
) -> logging.Logger:
|
||||||
|
"""
|
||||||
|
Setup application logging.
|
||||||
|
|
||||||
|
This is the main entry point for configuring logging. Call this
|
||||||
|
early in application startup.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_size_mb: Maximum log file size in MB
|
||||||
|
backup_count: Number of backup files to keep
|
||||||
|
log_level: Logging level string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The configured application logger
|
||||||
|
"""
|
||||||
|
return _logging_manager.setup(max_size_mb, backup_count, log_level)
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger() -> logging.Logger:
|
||||||
|
"""
|
||||||
|
Get the application logger.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The application logger instance
|
||||||
|
"""
|
||||||
|
return _logging_manager.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def set_log_level(level: str) -> bool:
|
||||||
|
"""
|
||||||
|
Set the application log level.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: Log level string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful, False otherwise
|
||||||
|
"""
|
||||||
|
return _logging_manager.set_level(level)
|
||||||
|
|
||||||
|
|
||||||
|
def reload_logging(
|
||||||
|
max_size_mb: Optional[int] = None,
|
||||||
|
backup_count: Optional[int] = None,
|
||||||
|
log_level: Optional[str] = None
|
||||||
|
) -> logging.Logger:
|
||||||
|
"""
|
||||||
|
Reload logging configuration.
|
||||||
|
|
||||||
|
Useful when settings change at runtime.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_size_mb: New maximum log file size
|
||||||
|
backup_count: New backup count
|
||||||
|
log_level: New log level
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The reconfigured logger
|
||||||
|
"""
|
||||||
|
return _logging_manager.setup(max_size_mb, backup_count, log_level)
|
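The manual rotation above boils down to: move the oversized log aside under a timestamped name, then prune the oldest backups until only `backup_count` remain. A self-contained sketch of that pattern (the `rotate_if_needed` helper and the 1 KiB limit are illustrative, not part of the module):

```python
import datetime
import glob
import os
import shutil
import tempfile
from pathlib import Path


def rotate_if_needed(log_file: Path, max_bytes: int, backup_count: int) -> bool:
    """Move the log aside once it reaches max_bytes, then prune old backups."""
    if not log_file.exists() or log_file.stat().st_size < max_bytes:
        return False
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    shutil.move(str(log_file), f"{log_file}.{timestamp}")
    # Oldest backups sort first because the suffix is a sortable timestamp.
    backups = sorted(glob.glob(str(log_file.parent / f"{log_file.name}.*")))
    while len(backups) > backup_count:
        os.remove(backups.pop(0))
    return True


tmp = Path(tempfile.mkdtemp())
log = tmp / "oai.log"
log.write_text("x" * 2048)  # over the 1 KiB limit used below
rotated = rotate_if_needed(log, max_bytes=1024, backup_count=3)
print(rotated, log.exists())  # True False
```

Unlike `RotatingFileHandler`'s numbered `.1`, `.2` suffixes, timestamped names never collide with an existing backup on rename.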
134 pyproject.toml Normal file
@@ -0,0 +1,134 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "oai"
version = "3.0.0-b2"  # MUST match oai/__init__.py __version__
description = "OpenRouter AI Chat Client - A feature-rich terminal-based chat application"
readme = "README.md"
license = {text = "MIT"}
authors = [
    {name = "Rune", email = "rune@example.com"}
]
maintainers = [
    {name = "Rune", email = "rune@example.com"}
]
keywords = [
    "ai",
    "chat",
    "openrouter",
    "cli",
    "terminal",
    "mcp",
    "llm",
]
classifiers = [
    "Development Status :: 4 - Beta",
    "Environment :: Console",
    "Intended Audience :: Developers",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Utilities",
]
requires-python = ">=3.10"
dependencies = [
    "anyio>=4.0.0",
    "click>=8.0.0",
    "httpx>=0.24.0",
    "markdown-it-py>=3.0.0",
    "openrouter>=0.0.19",
    "packaging>=21.0",
    "pyperclip>=1.8.0",
    "requests>=2.28.0",
    "rich>=13.0.0",
    "textual>=0.50.0",
    "typer>=0.9.0",
    "mcp>=1.0.0",
]

[project.optional-dependencies]
dev = [
    "pytest>=7.0.0",
    "pytest-asyncio>=0.21.0",
    "pytest-cov>=4.0.0",
    "black>=23.0.0",
    "isort>=5.12.0",
    "mypy>=1.0.0",
    "ruff>=0.1.0",
]

[project.urls]
Homepage = "https://iurl.no/oai"
Repository = "https://gitlab.pm/rune/oai"
Documentation = "https://iurl.no/oai"
"Bug Tracker" = "https://gitlab.pm/rune/oai/issues"

[project.scripts]
oai = "oai.cli:main"

[tool.setuptools]
packages = ["oai", "oai.commands", "oai.config", "oai.core", "oai.mcp", "oai.providers", "oai.tui", "oai.tui.widgets", "oai.tui.screens", "oai.utils"]

[tool.setuptools.package-data]
oai = ["py.typed"]

[tool.black]
line-length = 100
target-version = ["py310", "py311", "py312"]
include = '\.pyi?$'
exclude = '''
/(
    \.git
  | \.mypy_cache
  | \.pytest_cache
  | \.venv
  | build
  | dist
)/
'''

[tool.isort]
profile = "black"
line_length = 100
skip_gitignore = true

[tool.mypy]
python_version = "3.10"
warn_return_any = true
warn_unused_configs = true
ignore_missing_imports = true
exclude = [
    "build",
    "dist",
    ".venv",
]

[tool.ruff]
line-length = 100
target-version = "py310"
select = [
    "E",   # pycodestyle errors
    "W",   # pycodestyle warnings
    "F",   # Pyflakes
    "I",   # isort
    "B",   # flake8-bugbear
    "C4",  # flake8-comprehensions
    "UP",  # pyupgrade
]
ignore = [
    "E501",  # line too long (handled by black)
    "B008",  # do not perform function calls in argument defaults
    "C901",  # too complex
]

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
asyncio_mode = "auto"
addopts = "-v --tb=short"
@@ -1,37 +1,26 @@
-anyio==4.11.0
-beautifulsoup4==4.14.2
-charset-normalizer==3.4.4
-click==8.3.1
-docopt==0.6.2
-h11==0.16.0
-httpcore==1.0.9
-httpx==0.28.1
-idna==3.11
-latex2mathml==3.78.1
-loguru==0.7.3
-markdown-it-py==4.0.0
-markdown2==2.5.4
-mdurl==0.1.2
-natsort==8.4.0
-openrouter==0.0.19
-pipreqs==0.4.13
-prompt_toolkit==3.0.52
-Pygments==2.19.2
-pyperclip==1.11.0
-python-dateutil==2.9.0.post0
-python-magic==0.4.27
-PyYAML==6.0.3
-requests==2.32.5
-rich==14.2.0
-shellingham==1.5.4
-six==1.17.0
-sniffio==1.3.1
-soupsieve==2.8
-svgwrite==1.4.3
-tqdm==4.67.1
-typer==0.20.0
-typing_extensions==4.15.0
-urllib3==2.5.0
-wavedrom==2.0.3.post3
-wcwidth==0.2.14
-yarg==0.1.10
+# oai.py v2.1.0-beta - Core Dependencies
+anyio>=4.11.0
+charset-normalizer>=3.4.4
+click>=8.3.1
+h11>=0.16.0
+httpcore>=1.0.9
+httpx>=0.28.1
+idna>=3.11
+markdown-it-py>=4.0.0
+mdurl>=0.1.2
+openrouter>=0.0.19
+packaging>=25.0
+prompt-toolkit>=3.0.52
+Pygments>=2.19.2
+pyperclip>=1.11.0
+requests>=2.32.5
+rich>=14.2.0
+shellingham>=1.5.4
+sniffio>=1.3.1
+typer>=0.20.0
+typing-extensions>=4.15.0
+urllib3>=2.5.0
+wcwidth>=0.2.14
+
+# MCP (Model Context Protocol)
+mcp>=1.25.0