2.1 #2

Merged
rune merged 10 commits from 2.1 into main 2026-02-03 09:02:44 +01:00
36 changed files with 11576 additions and 2485 deletions

15
.gitignore vendored
View File

@@ -22,8 +22,12 @@ Pipfile.lock # Consider if you want to include or exclude
._*
*~.nib
*~.xib
README.md.old
oai.zip
# Claude Code local settings
.claude/
# Added by author
*.zip
.note
diagnose.py
*.log
@@ -34,3 +38,10 @@ compiled/
images/oai-iOS-Default-1024x1024@1x.png
images/oai.icon/
b0.sh
*.bak
*.old
*.sh
*.back
requirements.txt
system_prompt.txt
CLAUDE*

411
README.md
View File

@@ -1,232 +1,301 @@
# oAI - OpenRouter AI Chat
# oAI - OpenRouter AI Chat Client
A terminal-based chat interface for OpenRouter API with conversation management, cost tracking, and rich formatting.
## Description
oAI is a command-line chat application that provides an interactive interface to OpenRouter's AI models. It features conversation persistence, file attachments, export capabilities, and detailed session metrics.
A powerful, extensible terminal-based chat client for OpenRouter API with **MCP (Model Context Protocol)** support, enabling AI to access local files and query SQLite databases.
## Features
- Interactive chat with multiple AI models via OpenRouter
- Model selection with search functionality
- Conversation save/load/export (Markdown, JSON, HTML)
- File attachment support (code files and images)
- Session cost tracking and credit monitoring
- Rich terminal formatting with syntax highlighting
- Persistent command history
- Configurable system prompts and token limits
- SQLite-based configuration and conversation storage
### Core Features
- 🤖 Interactive chat with 300+ AI models via OpenRouter
- 🔍 Model selection with search and filtering
- 💾 Conversation save/load/export (Markdown, JSON, HTML)
- 📎 File attachments (images, PDFs, code files)
- 💰 Real-time cost tracking and credit monitoring
- 🎨 Rich terminal UI with syntax highlighting
- 📝 Persistent command history with search (Ctrl+R)
- 🌐 Online mode (web search capabilities)
- 🧠 Conversation memory toggle
### MCP Integration
- 🔧 **File Mode**: AI can read, search, and list local files
- Automatic .gitignore filtering
- Virtual environment exclusion
- Large file handling (auto-truncates >50KB)
- ✍️ **Write Mode**: AI can modify files with permission
- Create, edit, delete files
- Move, copy, organize files
- Always requires explicit opt-in
- 🗄️ **Database Mode**: AI can query SQLite databases
- Read-only access (safe)
- Schema inspection
- Full SQL query support
## Requirements
- Python 3.7 or higher
- OpenRouter API key (get one at https://openrouter.ai)
## Screenshot (<span style="font-size:0.8em;">from version 1.0</span>)
[<img src="https://gitlab.pm/rune/oai/raw/branch/main/images/screenshot_01.png">](https://gitlab.pm/rune/oai/src/branch/main/README.md)
Screenshot of `/help` screen.
- Python 3.10-3.13
- OpenRouter API key ([get one here](https://openrouter.ai))
## Installation
### 1. Install Dependencies
Use the included `requirements.txt` file to install the dependencies:
### Option 1: Install from Source (Recommended)
```bash
pip install -r requirements.txt
# Clone the repository
git clone https://gitlab.pm/rune/oai.git
cd oai
# Install with pip
pip install -e .
```
### 2. Make the Script Executable
### Option 2: Pre-built Binary (macOS/Linux)
Download from [Releases](https://gitlab.pm/rune/oai/releases):
- **macOS (Apple Silicon)**: `oai_v2.1.0_mac_arm64.zip`
- **Linux (x86_64)**: `oai_v2.1.0_linux_x86_64.zip`
```bash
chmod +x oai.py
```
### 3. Copy to PATH
Copy the script to a directory in your `$PATH` environment variable. Common locations include:
```bash
# Option 1: System-wide (requires sudo)
sudo cp oai.py /usr/local/bin/oai
# Option 2: User-local (recommended)
# Extract and install
unzip oai_v2.1.0_*.zip
mkdir -p ~/.local/bin
cp oai.py ~/.local/bin/oai
mv oai ~/.local/bin/
# Add to PATH if not already (add to ~/.bashrc or ~/.zshrc)
# macOS only: Remove quarantine and approve
xattr -cr ~/.local/bin/oai
# Then right-click oai in Finder → Open With → Terminal → Click "Open"
```
### Add to PATH
```bash
# Add to ~/.zshrc or ~/.bashrc
export PATH="$HOME/.local/bin:$PATH"
```
### 4. Verify Installation
## Quick Start
```bash
oai
# Start the chat client
oai chat
# Or with options
oai chat --model gpt-4o --mcp
```
### 5. Alternative Installation (for *nix systems)
If you have issues with the above method you can add an alias in your `.bashrc`, `.zshrc` etc.
```bash
alias oai='python3 <path to your file>'
```
On first run, you will be prompted to enter your OpenRouter API key.
### 6. Use Binaries
You can also just download the supplied binary for either Mac wit Mx (M1, M2 etc) `oai_mac_arm64.zip` and follow [#3](https://gitlab.pm/rune/oai#3-copy-to-path). Or download for Linux (64bit) `oai_linux_x86_64.zip` and also follow [#3](https://gitlab.pm/rune/oai#3-copy-to-path).
## Usage
### Starting the Application
```bash
oai
```
On first run, you'll be prompted for your OpenRouter API key.
### Basic Commands
```
/help Show all available commands
/model Select an AI model
/config api Set OpenRouter API key
exit Quit the application
```bash
# In the chat interface:
/model # Select AI model
/help # Show all commands
/mcp on # Enable file/database access
/stats # View session statistics
exit # Quit
```
### Configuration
## MCP (Model Context Protocol)
All configuration is stored in `~/.config/oai/`:
- `oai_config.db` - SQLite database for settings and conversations
- `oai.log` - Application log file
- `history.txt` - Command history
MCP allows the AI to interact with your local files and databases.
### Common Workflows
### File Access
**Select a Model:**
```
/model
```bash
/mcp on # Enable MCP
/mcp add ~/Projects # Grant access to folder
/mcp list # View allowed folders
# Now ask the AI:
"List all Python files in Projects"
"Read and explain main.py"
"Search for files containing 'TODO'"
```
**Paste from clipboard:**
Paste and send content to model
```
/paste
### Write Mode
```bash
/mcp write on # Enable file modifications
# AI can now:
"Create a new file called utils.py"
"Edit config.json and update the API URL"
"Delete the old backup files" # Always asks for confirmation
```
Paste with prompt and send content to model
```
/paste Analyze this text
```
### Database Mode
**Start Chatting:**
```
You> Hello, how are you?
```
```bash
/mcp add db ~/app/data.db # Add database
/mcp db 1 # Switch to database mode
**Attach Files:**
# Ask the AI:
"Show all tables"
"Find users created this month"
"What's the schema for the orders table?"
```
You> Debug this code @/path/to/script.py
You> Analyze this image @/path/to/image.png
```
**Save Conversation:**
```
/save my_conversation
```
**Export to File:**
```
/export md notes.md
/export json backup.json
/export html report.html
```
**View Session Stats:**
```
/stats
/credits
```
**Prevous commands input:**
Use the up/down arrows to see earlier `/`commands and earlier input to model and `<enter>` to execute the same command or resend the same input.
## Command Reference
Use `/help` within the application for a complete command reference organized by category:
- Session Commands
- Model Commands
- Configuration
- Token & System
- Conversation Management
- Monitoring & Stats
- File Attachments
### Chat Commands
| Command | Description |
|---------|-------------|
| `/help [cmd]` | Show help |
| `/model [search]` | Select model |
| `/info [model]` | Model details |
| `/memory on\|off` | Toggle context |
| `/online on\|off` | Toggle web search |
| `/retry` | Resend last message |
| `/clear` | Clear screen |
## Configuration Options
### MCP Commands
| Command | Description |
|---------|-------------|
| `/mcp on\|off` | Enable/disable MCP |
| `/mcp status` | Show MCP status |
| `/mcp add <path>` | Add folder |
| `/mcp add db <path>` | Add database |
| `/mcp list` | List folders |
| `/mcp db list` | List databases |
| `/mcp db <n>` | Switch to database |
| `/mcp files` | Switch to file mode |
| `/mcp write on\|off` | Toggle write mode |
- API Key: `/config api`
- Base URL: `/config url`
- Streaming: `/config stream on|off`
- Default Model: `/config model`
- Cost Warning: `/config costwarning <amount>`
- Max Token Limit: `/config maxtoken <value>`
### Conversation Commands
| Command | Description |
|---------|-------------|
| `/save <name>` | Save conversation |
| `/load <name>` | Load conversation |
| `/list` | List saved conversations |
| `/delete <name>` | Delete conversation |
| `/export md\|json\|html <file>` | Export |
## File Support
### Configuration
| Command | Description |
|---------|-------------|
| `/config` | View settings |
| `/config api` | Set API key |
| `/config model <id>` | Set default model |
| `/config stream on\|off` | Toggle streaming |
| `/stats` | Session statistics |
| `/credits` | Check credits |
**Supported Code Extensions:**
.py, .js, .ts, .cs, .java, .c, .cpp, .h, .hpp, .rb, .ruby, .php, .swift, .kt, .kts, .go, .sh, .bat, .ps1, .R, .scala, .pl, .lua, .dart, .elm, .xml, .json, .yaml, .yml, .md, .txt
## CLI Options
**Image Support:**
Any image format with proper MIME type (PNG, JPEG, GIF, etc.)
```bash
oai chat [OPTIONS]
## Data Storage
Options:
-m, --model TEXT Model ID to use
-s, --system TEXT System prompt
-o, --online Enable online mode
--mcp Enable MCP server
--help Show help
```
- Configuration: `~/.config/oai/oai_config.db`
- Logs: `~/.config/oai/oai.log`
- History: `~/.config/oai/history.txt`
Other commands:
```bash
oai config [setting] [value] # Configure settings
oai version # Show version
oai credits # Check credits
```
## Configuration
Configuration is stored in `~/.config/oai/`:
| File | Purpose |
|------|---------|
| `oai_config.db` | Settings, conversations, MCP config |
| `oai.log` | Application logs |
| `history.txt` | Command history |
## Project Structure
```
oai/
├── oai/
│ ├── __init__.py
│ ├── __main__.py # Entry point for python -m oai
│ ├── cli.py # Main CLI interface
│ ├── constants.py # Configuration constants
│ ├── commands/ # Slash command handlers
│ ├── config/ # Settings and database
│ ├── core/ # Chat client and session
│ ├── mcp/ # MCP server and tools
│ ├── providers/ # AI provider abstraction
│ ├── ui/ # Terminal UI utilities
│ └── utils/ # Logging, export, etc.
├── pyproject.toml # Package configuration
├── build.sh # Binary build script
└── README.md
```
## Troubleshooting
### macOS Binary Issues
```bash
# Remove quarantine attribute
xattr -cr ~/.local/bin/oai
# Then in Finder: right-click oai → Open With → Terminal → Click "Open"
# After this, oai works from any terminal
```
### MCP Not Working
```bash
# Check if model supports function calling
/info # Look for "tools" in supported parameters
# Check MCP status
/mcp status
# View logs
tail -f ~/.config/oai/oai.log
```
### Import Errors
```bash
# Reinstall package
pip install -e . --force-reinstall
```
## Version History
### v2.1.0 (Current)
- 🏗️ Complete codebase refactoring to modular package structure
- 🔌 Extensible provider architecture for adding new AI providers
- 📦 Proper Python packaging with pyproject.toml
- ✨ MCP integration (file access, write mode, database queries)
- 🔧 Command registry pattern for slash commands
- 📊 Improved cost tracking and session statistics
### v1.9.x
- Single-file implementation
- Core chat functionality
- File attachments
- Conversation management
## License
MIT License
Copyright (c) 2024 Rune Olsen
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Full license: https://opensource.org/licenses/MIT
MIT License - See [LICENSE](LICENSE) for details.
## Author
**Rune Olsen**
- Project: https://iurl.no/oai
- Repository: https://gitlab.pm/rune/oai
Blog: https://blog.rune.pm
## Contributing
## Version
1. Fork the repository
2. Create a feature branch
3. Submit a pull request
1.0
---
## Support
For issues, questions, or contributions, visit https://iurl.no/oai and create an issue.
**⭐ Star this project if you find it useful!**

2273
oai.py

File diff suppressed because it is too large Load Diff

26
oai/__init__.py Normal file
View File

@@ -0,0 +1,26 @@
"""
oAI - OpenRouter AI Chat Client
A feature-rich terminal-based chat application that provides an interactive CLI
interface to OpenRouter's unified AI API with advanced Model Context Protocol (MCP)
integration for filesystem and database access.
Author: Rune
License: MIT
"""
__version__ = "2.1.0"
__author__ = "Rune"
__license__ = "MIT"
# Lazy imports to avoid circular dependencies and improve startup time
# Full imports are available via submodules:
# from oai.config import Settings, Database
# from oai.providers import OpenRouterProvider, AIProvider
# from oai.mcp import MCPManager
__all__ = [
"__version__",
"__author__",
"__license__",
]

8
oai/__main__.py Normal file
View File

@@ -0,0 +1,8 @@
"""
Entry point for running oAI as a module: python -m oai
"""
from oai.cli import main
if __name__ == "__main__":
main()

719
oai/cli.py Normal file
View File

@@ -0,0 +1,719 @@
"""
Main CLI entry point for oAI.
This module provides the command-line interface for the oAI chat application.
"""
import os
import sys
from pathlib import Path
from typing import Optional
import typer
from prompt_toolkit import PromptSession
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.history import FileHistory
from rich.markdown import Markdown
from rich.panel import Panel
from oai import __version__
from oai.commands import register_all_commands, registry
from oai.commands.registry import CommandContext, CommandStatus
from oai.config.database import Database
from oai.config.settings import Settings
from oai.constants import (
APP_NAME,
APP_URL,
APP_VERSION,
CONFIG_DIR,
HISTORY_FILE,
VALID_COMMANDS,
)
from oai.core.client import AIClient
from oai.core.session import ChatSession
from oai.mcp.manager import MCPManager
from oai.providers.base import UsageStats
from oai.providers.openrouter import OpenRouterProvider
from oai.ui.console import (
clear_screen,
console,
display_panel,
print_error,
print_info,
print_success,
print_warning,
)
from oai.ui.tables import create_model_table, display_paginated_table
from oai.utils.logging import LoggingManager, get_logger
# Create Typer app
app = typer.Typer(
name="oai",
help=f"oAI - OpenRouter AI Chat Client\n\nVersion: {APP_VERSION}",
add_completion=False,
epilog="For more information, visit: " + APP_URL,
)
@app.callback(invoke_without_command=True)
def main_callback(
ctx: typer.Context,
version_flag: bool = typer.Option(
False,
"--version",
"-v",
help="Show version information",
is_flag=True,
),
) -> None:
"""Main callback to handle global options."""
# Show version with update check if --version flag
if version_flag:
version_info = check_for_updates(APP_VERSION)
console.print(version_info)
raise typer.Exit()
# Show version with update check when --help is requested
if "--help" in sys.argv or "-h" in sys.argv:
version_info = check_for_updates(APP_VERSION)
console.print(f"\n{version_info}\n")
# Continue to subcommand if provided
if ctx.invoked_subcommand is None:
return
def check_for_updates(current_version: str) -> str:
"""Check for available updates."""
import requests
from packaging import version as pkg_version
try:
response = requests.get(
"https://gitlab.pm/api/v1/repos/rune/oai/releases/latest",
headers={"Content-Type": "application/json"},
timeout=1.0,
)
response.raise_for_status()
data = response.json()
version_online = data.get("tag_name", "").lstrip("v")
if not version_online:
return f"[bold green]oAI version {current_version}[/]"
current = pkg_version.parse(current_version)
latest = pkg_version.parse(version_online)
if latest > current:
return (
f"[bold green]oAI version {current_version}[/] "
f"[bold red](Update available: {current_version}{version_online})[/]"
)
return f"[bold green]oAI version {current_version} (up to date)[/]"
except Exception:
return f"[bold green]oAI version {current_version}[/]"
def show_welcome(settings: Settings, version_info: str) -> None:
"""Display welcome message."""
console.print(Panel.fit(
f"{version_info}\n\n"
"[bold cyan]Commands:[/] /help for commands, /model to select model\n"
"[bold cyan]MCP:[/] /mcp on to enable file/database access\n"
"[bold cyan]Exit:[/] Type 'exit', 'quit', or 'bye'",
title=f"[bold green]Welcome to {APP_NAME}[/]",
border_style="green",
))
def select_model(client: AIClient, search_term: Optional[str] = None) -> Optional[dict]:
"""Display model selection interface."""
try:
models = client.provider.get_raw_models()
if not models:
print_error("No models available")
return None
# Filter by search term if provided
if search_term:
search_lower = search_term.lower()
models = [m for m in models if search_lower in m.get("id", "").lower()]
if not models:
print_error(f"No models found matching '{search_term}'")
return None
# Create and display table
table = create_model_table(models)
display_paginated_table(
table,
f"[bold green]Available Models ({len(models)})[/]",
)
# Prompt for selection
console.print("")
try:
choice = input("Enter model number (or press Enter to cancel): ").strip()
except (EOFError, KeyboardInterrupt):
return None
if not choice:
return None
try:
index = int(choice) - 1
if 0 <= index < len(models):
selected = models[index]
print_success(f"Selected model: {selected['id']}")
return selected
except ValueError:
pass
print_error("Invalid selection")
return None
except Exception as e:
print_error(f"Failed to fetch models: {e}")
return None
def run_chat_loop(
session: ChatSession,
prompt_session: PromptSession,
settings: Settings,
) -> None:
"""Run the main chat loop."""
logger = get_logger()
mcp_manager = session.mcp_manager
while True:
try:
# Build prompt prefix
prefix = "You> "
if mcp_manager and mcp_manager.enabled:
if mcp_manager.mode == "files":
if mcp_manager.write_enabled:
prefix = "[🔧✍️ MCP: Files+Write] You> "
else:
prefix = "[🔧 MCP: Files] You> "
elif mcp_manager.mode == "database" and mcp_manager.selected_db_index is not None:
prefix = f"[🗄️ MCP: DB #{mcp_manager.selected_db_index + 1}] You> "
# Get user input
user_input = prompt_session.prompt(
prefix,
auto_suggest=AutoSuggestFromHistory(),
).strip()
if not user_input:
continue
# Handle escape sequence
if user_input.startswith("//"):
user_input = user_input[1:]
# Check for exit
if user_input.lower() in ["exit", "quit", "bye"]:
console.print(
f"\n[bold yellow]Goodbye![/]\n"
f"[dim]Session: {session.stats.total_tokens:,} tokens, "
f"${session.stats.total_cost:.4f}[/]"
)
logger.info(
f"Session ended. Messages: {session.stats.message_count}, "
f"Tokens: {session.stats.total_tokens}, "
f"Cost: ${session.stats.total_cost:.4f}"
)
return
# Check for unknown commands
if user_input.startswith("/"):
cmd_word = user_input.split()[0].lower()
if not registry.is_command(user_input):
# Check if it's a valid command prefix
is_valid = any(cmd_word.startswith(cmd) for cmd in VALID_COMMANDS)
if not is_valid:
print_error(f"Unknown command: {cmd_word}")
print_info("Type /help to see available commands.")
continue
# Try to execute as command
context = session.get_context()
result = registry.execute(user_input, context)
if result:
# Update session state from context
session.memory_enabled = context.memory_enabled
session.memory_start_index = context.memory_start_index
session.online_enabled = context.online_enabled
session.middle_out_enabled = context.middle_out_enabled
session.session_max_token = context.session_max_token
session.current_index = context.current_index
session.system_prompt = context.session_system_prompt
if result.status == CommandStatus.EXIT:
return
# Handle special results
if result.data:
# Retry - resend last prompt
if "retry_prompt" in result.data:
user_input = result.data["retry_prompt"]
# Fall through to send message
# Paste - send clipboard content
elif "paste_prompt" in result.data:
user_input = result.data["paste_prompt"]
# Fall through to send message
# Model selection
elif "show_model_selector" in result.data:
search = result.data.get("search", "")
model = select_model(session.client, search if search else None)
if model:
session.set_model(model)
# If this came from /config model, also save as default
if result.data.get("set_as_default"):
settings.set_default_model(model["id"])
print_success(f"Default model set to: {model['id']}")
continue
# Load conversation
elif "load_conversation" in result.data:
history = result.data.get("history", [])
session.history.clear()
from oai.core.session import HistoryEntry
for entry in history:
session.history.append(HistoryEntry(
prompt=entry.get("prompt", ""),
response=entry.get("response", ""),
prompt_tokens=entry.get("prompt_tokens", 0),
completion_tokens=entry.get("completion_tokens", 0),
msg_cost=entry.get("msg_cost", 0.0),
))
session.current_index = len(session.history) - 1
continue
else:
# Normal command completed
continue
else:
# Command completed with no special data
continue
# Ensure model is selected
if not session.selected_model:
print_warning("Please select a model first with /model")
continue
# Send message
stream = settings.stream_enabled
if mcp_manager and mcp_manager.enabled:
tools = session.get_mcp_tools()
if tools:
stream = False # Disable streaming with tools
if stream:
console.print(
"[bold green]Streaming response...[/] "
"[dim](Press Ctrl+C to cancel)[/]"
)
if session.online_enabled:
console.print("[dim cyan]🌐 Online mode active[/]")
console.print("")
try:
response_text, usage, response_time = session.send_message(
user_input,
stream=stream,
)
except Exception as e:
print_error(f"Error: {e}")
logger.error(f"Message error: {e}")
continue
if not response_text:
print_error("No response received")
continue
# Display non-streaming response
if not stream:
console.print()
display_panel(
Markdown(response_text),
title="[bold green]AI Response[/]",
border_style="green",
)
# Calculate cost and tokens
cost = 0.0
tokens = 0
estimated = False
if usage and (usage.prompt_tokens > 0 or usage.completion_tokens > 0):
tokens = usage.total_tokens
if usage.total_cost_usd:
cost = usage.total_cost_usd
else:
cost = session.client.estimate_cost(
session.selected_model["id"],
usage.prompt_tokens,
usage.completion_tokens,
)
else:
# Estimate tokens when usage not available (streaming fallback)
# Rough estimate: ~4 characters per token for English text
est_input_tokens = len(user_input) // 4 + 1
est_output_tokens = len(response_text) // 4 + 1
tokens = est_input_tokens + est_output_tokens
cost = session.client.estimate_cost(
session.selected_model["id"],
est_input_tokens,
est_output_tokens,
)
# Create estimated usage for session tracking
usage = UsageStats(
prompt_tokens=est_input_tokens,
completion_tokens=est_output_tokens,
total_tokens=tokens,
)
estimated = True
# Add to history
session.add_to_history(user_input, response_text, usage, cost)
# Display metrics
est_marker = "~" if estimated else ""
context_info = ""
if session.memory_enabled:
context_count = len(session.history) - session.memory_start_index
if context_count > 1:
context_info = f", Context: {context_count} msg(s)"
else:
context_info = ", Memory: OFF"
online_emoji = " 🌐" if session.online_enabled else ""
mcp_emoji = ""
if mcp_manager and mcp_manager.enabled:
if mcp_manager.mode == "files":
mcp_emoji = " 🔧"
elif mcp_manager.mode == "database":
mcp_emoji = " 🗄️"
console.print(
f"\n[dim blue]📊 {est_marker}{tokens} tokens | {est_marker}${cost:.4f} | {response_time:.2f}s"
f"{context_info}{online_emoji}{mcp_emoji} | "
f"Session: {est_marker}{session.stats.total_tokens:,} tokens | "
f"{est_marker}${session.stats.total_cost:.4f}[/]"
)
# Check warnings
warnings = session.check_warnings()
for warning in warnings:
print_warning(warning)
# Offer to copy
console.print("")
try:
from oai.ui.prompts import prompt_copy_response
prompt_copy_response(response_text)
except Exception:
pass
console.print("")
except KeyboardInterrupt:
console.print("\n[dim]Input cancelled[/]")
continue
except EOFError:
console.print("\n[bold yellow]Goodbye![/]")
return
@app.command()
def chat(
model: Optional[str] = typer.Option(
None,
"--model",
"-m",
help="Model ID to use",
),
system: Optional[str] = typer.Option(
None,
"--system",
"-s",
help="System prompt",
),
online: bool = typer.Option(
False,
"--online",
"-o",
help="Enable online mode",
),
mcp: bool = typer.Option(
False,
"--mcp",
help="Enable MCP server",
),
) -> None:
"""Start an interactive chat session."""
# Setup logging
logging_manager = LoggingManager()
logging_manager.setup()
logger = get_logger()
# Clear screen
clear_screen()
# Load settings
settings = Settings.load()
# Check API key
if not settings.api_key:
print_error("No API key configured")
print_info("Run: oai --config api to set your API key")
raise typer.Exit(1)
# Initialize client
try:
client = AIClient(
api_key=settings.api_key,
base_url=settings.base_url,
)
except Exception as e:
print_error(f"Failed to initialize client: {e}")
raise typer.Exit(1)
# Register commands
register_all_commands()
# Check for updates and show welcome
version_info = check_for_updates(APP_VERSION)
show_welcome(settings, version_info)
# Initialize MCP manager
mcp_manager = MCPManager()
if mcp:
result = mcp_manager.enable()
if result["success"]:
print_success("MCP enabled")
else:
print_warning(f"MCP: {result.get('error', 'Failed to enable')}")
# Create session
session = ChatSession(
client=client,
settings=settings,
mcp_manager=mcp_manager,
)
# Set system prompt
if system:
session.system_prompt = system
print_info(f"System prompt: {system}")
# Set online mode
if online:
session.online_enabled = True
print_info("Online mode enabled")
# Select model
if model:
raw_model = client.get_raw_model(model)
if raw_model:
session.set_model(raw_model)
else:
print_warning(f"Model '{model}' not found")
elif settings.default_model:
raw_model = client.get_raw_model(settings.default_model)
if raw_model:
session.set_model(raw_model)
else:
print_warning(f"Default model '{settings.default_model}' not available")
# Setup prompt session
HISTORY_FILE.parent.mkdir(parents=True, exist_ok=True)
prompt_session = PromptSession(
history=FileHistory(str(HISTORY_FILE)),
)
# Run chat loop
run_chat_loop(session, prompt_session, settings)
@app.command()
def config(
setting: Optional[str] = typer.Argument(
None,
help="Setting to configure (api, url, model, system, stream, costwarning, maxtoken, online, log, loglevel)",
),
value: Optional[str] = typer.Argument(
None,
help="Value to set",
),
) -> None:
"""View or modify configuration settings."""
settings = Settings.load()
if not setting:
# Show all settings
from rich.table import Table
from oai.constants import DEFAULT_SYSTEM_PROMPT
table = Table("Setting", "Value", show_header=True, header_style="bold magenta")
table.add_row("API Key", "***" + settings.api_key[-4:] if settings.api_key else "Not set")
table.add_row("Base URL", settings.base_url)
table.add_row("Default Model", settings.default_model or "Not set")
# Show system prompt status
if settings.default_system_prompt is None:
system_prompt_display = f"[default] {DEFAULT_SYSTEM_PROMPT[:40]}..."
elif settings.default_system_prompt == "":
system_prompt_display = "[blank]"
else:
system_prompt_display = settings.default_system_prompt[:50] + "..." if len(settings.default_system_prompt) > 50 else settings.default_system_prompt
table.add_row("System Prompt", system_prompt_display)
table.add_row("Streaming", "on" if settings.stream_enabled else "off")
table.add_row("Cost Warning", f"${settings.cost_warning_threshold:.4f}")
table.add_row("Max Tokens", str(settings.max_tokens))
table.add_row("Default Online", "on" if settings.default_online_mode else "off")
table.add_row("Log Level", settings.log_level)
display_panel(table, title="[bold green]Configuration[/]")
return
setting = setting.lower()
if setting == "api":
if value:
settings.set_api_key(value)
else:
from oai.ui.prompts import prompt_input
new_key = prompt_input("Enter API key", password=True)
if new_key:
settings.set_api_key(new_key)
print_success("API key updated")
elif setting == "url":
settings.set_base_url(value or "https://openrouter.ai/api/v1")
print_success(f"Base URL set to: {settings.base_url}")
elif setting == "model":
if value:
settings.set_default_model(value)
print_success(f"Default model set to: {value}")
else:
print_info(f"Current default model: {settings.default_model or 'Not set'}")
elif setting == "system":
from oai.constants import DEFAULT_SYSTEM_PROMPT
if value:
# Decode escape sequences like \n for newlines
value = value.encode().decode('unicode_escape')
settings.set_default_system_prompt(value)
if value:
print_success(f"Default system prompt set to: {value}")
else:
print_success("Default system prompt set to blank.")
else:
if settings.default_system_prompt is None:
print_info(f"Using hardcoded default: {DEFAULT_SYSTEM_PROMPT[:60]}...")
elif settings.default_system_prompt == "":
print_info("System prompt: [blank]")
else:
print_info(f"System prompt: {settings.default_system_prompt}")
elif setting == "stream":
if value and value.lower() in ["on", "off"]:
settings.set_stream_enabled(value.lower() == "on")
print_success(f"Streaming {'enabled' if settings.stream_enabled else 'disabled'}")
else:
print_info("Usage: oai config stream [on|off]")
elif setting == "costwarning":
if value:
try:
threshold = float(value)
settings.set_cost_warning_threshold(threshold)
print_success(f"Cost warning threshold set to: ${threshold:.4f}")
except ValueError:
print_error("Please enter a valid number")
else:
print_info(f"Current threshold: ${settings.cost_warning_threshold:.4f}")
elif setting == "maxtoken":
if value:
try:
max_tok = int(value)
settings.set_max_tokens(max_tok)
print_success(f"Max tokens set to: {max_tok}")
except ValueError:
print_error("Please enter a valid number")
else:
print_info(f"Current max tokens: {settings.max_tokens}")
elif setting == "online":
if value and value.lower() in ["on", "off"]:
settings.set_default_online_mode(value.lower() == "on")
print_success(f"Default online mode {'enabled' if settings.default_online_mode else 'disabled'}")
else:
print_info("Usage: oai config online [on|off]")
elif setting == "loglevel":
valid_levels = ["debug", "info", "warning", "error", "critical"]
if value and value.lower() in valid_levels:
settings.set_log_level(value.lower())
print_success(f"Log level set to: {value.lower()}")
else:
print_info(f"Valid levels: {', '.join(valid_levels)}")
else:
print_error(f"Unknown setting: {setting}")
@app.command()
def version() -> None:
"""Show version information."""
version_info = check_for_updates(APP_VERSION)
console.print(version_info)
@app.command()
def credits() -> None:
"""Check account credits."""
settings = Settings.load()
if not settings.api_key:
print_error("No API key configured")
raise typer.Exit(1)
client = AIClient(api_key=settings.api_key, base_url=settings.base_url)
credits_data = client.get_credits()
if not credits_data:
print_error("Failed to fetch credits")
raise typer.Exit(1)
from rich.table import Table
table = Table("Metric", "Value", show_header=True, header_style="bold magenta")
table.add_row("Total Credits", credits_data.get("total_credits_formatted", "N/A"))
table.add_row("Used Credits", credits_data.get("used_credits_formatted", "N/A"))
table.add_row("Credits Left", credits_data.get("credits_left_formatted", "N/A"))
display_panel(table, title="[bold green]Account Credits[/]")
def main() -> None:
"""Main entry point."""
# Default to 'chat' command if no arguments provided
if len(sys.argv) == 1:
sys.argv.append("chat")
app()
if __name__ == "__main__":
main()

24
oai/commands/__init__.py Normal file
View File

@@ -0,0 +1,24 @@
"""
Command system for oAI.
This module provides a command registry and handler system
for processing slash commands in the chat interface.
"""
from oai.commands.registry import (
Command,
CommandRegistry,
CommandContext,
CommandResult,
registry,
)
from oai.commands.handlers import register_all_commands
__all__ = [
"Command",
"CommandRegistry",
"CommandContext",
"CommandResult",
"registry",
"register_all_commands",
]

1441
oai/commands/handlers.py Normal file

File diff suppressed because it is too large Load Diff

381
oai/commands/registry.py Normal file
View File

@@ -0,0 +1,381 @@
"""
Command registry for oAI.
This module defines the command system infrastructure including
the Command base class, CommandContext for state, and CommandRegistry
for managing available commands.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
from oai.utils.logging import get_logger
if TYPE_CHECKING:
from oai.config.settings import Settings
from oai.providers.base import AIProvider, ModelInfo
from oai.mcp.manager import MCPManager
class CommandStatus(str, Enum):
"""Status of command execution."""
SUCCESS = "success"
ERROR = "error"
CONTINUE = "continue" # Continue to next handler
EXIT = "exit" # Exit the application
@dataclass
class CommandResult:
"""
Result of a command execution.
Attributes:
status: Execution status
message: Optional message to display
data: Optional data payload
should_continue: Whether to continue the main loop
"""
status: CommandStatus = CommandStatus.SUCCESS
message: Optional[str] = None
data: Optional[Any] = None
should_continue: bool = True
@classmethod
def success(cls, message: Optional[str] = None, data: Any = None) -> "CommandResult":
"""Create a success result."""
return cls(status=CommandStatus.SUCCESS, message=message, data=data)
@classmethod
def error(cls, message: str) -> "CommandResult":
"""Create an error result."""
return cls(status=CommandStatus.ERROR, message=message)
@classmethod
def exit(cls, message: Optional[str] = None) -> "CommandResult":
"""Create an exit result."""
return cls(status=CommandStatus.EXIT, message=message, should_continue=False)
@dataclass
class CommandContext:
"""
Context object providing state to command handlers.
Contains all the session state needed by commands including
settings, provider, conversation history, and MCP manager.
Attributes:
settings: Application settings
provider: AI provider instance
mcp_manager: MCP manager instance
selected_model: Currently selected model
session_history: Conversation history
session_system_prompt: Current system prompt
memory_enabled: Whether memory is enabled
online_enabled: Whether online mode is enabled
session_tokens: Session token counts
session_cost: Session cost total
"""
settings: Optional["Settings"] = None
provider: Optional["AIProvider"] = None
mcp_manager: Optional["MCPManager"] = None
selected_model: Optional["ModelInfo"] = None
selected_model_raw: Optional[Dict[str, Any]] = None
session_history: List[Dict[str, Any]] = field(default_factory=list)
session_system_prompt: str = ""
memory_enabled: bool = True
memory_start_index: int = 0
online_enabled: bool = False
middle_out_enabled: bool = False
session_max_token: int = 0
total_input_tokens: int = 0
total_output_tokens: int = 0
total_cost: float = 0.0
message_count: int = 0
current_index: int = 0
@dataclass
class CommandHelp:
"""
Help information for a command.
Attributes:
description: Brief description
usage: Usage syntax
examples: List of (description, example) tuples
notes: Additional notes
aliases: Command aliases
"""
description: str
usage: str = ""
examples: List[tuple] = field(default_factory=list)
notes: str = ""
aliases: List[str] = field(default_factory=list)
class Command(ABC):
"""
Abstract base class for all commands.
Commands implement the execute method to handle their logic.
They can also provide help information and aliases.
"""
@property
@abstractmethod
def name(self) -> str:
"""Get the primary command name (e.g., '/help')."""
pass
@property
def aliases(self) -> List[str]:
"""Get command aliases (e.g., ['/h'] for help)."""
return []
@property
@abstractmethod
def help(self) -> CommandHelp:
"""Get command help information."""
pass
@abstractmethod
def execute(self, args: str, context: CommandContext) -> CommandResult:
"""
Execute the command.
Args:
args: Arguments passed to the command
context: Command execution context
Returns:
CommandResult indicating success/failure
"""
pass
def matches(self, input_text: str) -> bool:
"""
Check if this command matches the input.
Args:
input_text: User input text
Returns:
True if this command should handle the input
"""
input_lower = input_text.lower()
cmd_word = input_lower.split()[0] if input_lower.split() else ""
# Check primary name
if cmd_word == self.name.lower():
return True
# Check aliases
for alias in self.aliases:
if cmd_word == alias.lower():
return True
return False
def get_args(self, input_text: str) -> str:
"""
Extract arguments from the input text.
Args:
input_text: Full user input
Returns:
Arguments portion of the input
"""
parts = input_text.split(maxsplit=1)
return parts[1] if len(parts) > 1 else ""
class CommandRegistry:
"""
Registry for managing available commands.
Provides registration, lookup, and execution of commands.
"""
def __init__(self):
"""Initialize an empty command registry."""
self._commands: Dict[str, Command] = {}
self._aliases: Dict[str, str] = {}
self.logger = get_logger()
def register(self, command: Command) -> None:
"""
Register a command.
Args:
command: Command instance to register
Raises:
ValueError: If command name already registered
"""
name = command.name.lower()
if name in self._commands:
raise ValueError(f"Command '{name}' already registered")
self._commands[name] = command
# Register aliases
for alias in command.aliases:
alias_lower = alias.lower()
if alias_lower in self._aliases:
self.logger.warning(
f"Alias '{alias}' already registered, overwriting"
)
self._aliases[alias_lower] = name
self.logger.debug(f"Registered command: {name}")
def register_function(
self,
name: str,
handler: Callable[[str, CommandContext], CommandResult],
description: str,
usage: str = "",
aliases: Optional[List[str]] = None,
examples: Optional[List[tuple]] = None,
notes: str = "",
) -> None:
"""
Register a function-based command.
Convenience method for simple commands that don't need
a full Command class.
Args:
name: Command name (e.g., '/help')
handler: Function to execute
description: Help description
usage: Usage syntax
aliases: Command aliases
examples: Example usages
notes: Additional notes
"""
aliases = aliases or []
examples = examples or []
class FunctionCommand(Command):
@property
def name(self) -> str:
return name
@property
def aliases(self) -> List[str]:
return aliases
@property
def help(self) -> CommandHelp:
return CommandHelp(
description=description,
usage=usage,
examples=examples,
notes=notes,
aliases=aliases,
)
def execute(self, args: str, context: CommandContext) -> CommandResult:
return handler(args, context)
self.register(FunctionCommand())
def get(self, name: str) -> Optional[Command]:
"""
Get a command by name or alias.
Args:
name: Command name or alias
Returns:
Command instance or None if not found
"""
name_lower = name.lower()
# Check direct match
if name_lower in self._commands:
return self._commands[name_lower]
# Check aliases
if name_lower in self._aliases:
return self._commands[self._aliases[name_lower]]
return None
def find(self, input_text: str) -> Optional[Command]:
"""
Find a command that matches the input.
Args:
input_text: User input text
Returns:
Matching Command or None
"""
cmd_word = input_text.lower().split()[0] if input_text.split() else ""
return self.get(cmd_word)
def execute(self, input_text: str, context: CommandContext) -> Optional[CommandResult]:
"""
Execute a command matching the input.
Args:
input_text: User input text
context: Execution context
Returns:
CommandResult or None if no matching command
"""
command = self.find(input_text)
if command:
args = command.get_args(input_text)
self.logger.debug(f"Executing command: {command.name} with args: {args}")
return command.execute(args, context)
return None
def is_command(self, input_text: str) -> bool:
"""
Check if input is a valid command.
Args:
input_text: User input text
Returns:
True if input matches a registered command
"""
return self.find(input_text) is not None
def list_commands(self) -> List[Command]:
"""
Get all registered commands.
Returns:
List of Command instances
"""
return list(self._commands.values())
def get_all_names(self) -> List[str]:
"""
Get all command names and aliases.
Returns:
List of command names including aliases
"""
names = list(self._commands.keys())
names.extend(self._aliases.keys())
return sorted(set(names))
# Global registry instance
registry = CommandRegistry()

11
oai/config/__init__.py Normal file
View File

@@ -0,0 +1,11 @@
"""
Configuration management for oAI.
This package handles all configuration persistence, settings management,
and database operations for the application.
"""
from oai.config.settings import Settings
from oai.config.database import Database
__all__ = ["Settings", "Database"]

472
oai/config/database.py Normal file
View File

@@ -0,0 +1,472 @@
"""
Database persistence layer for oAI.
This module provides a clean abstraction for SQLite operations including
configuration storage, conversation persistence, and MCP statistics tracking.
All database operations are centralized here for maintainability.
"""
import sqlite3
import json
import datetime
from pathlib import Path
from typing import Optional, List, Dict, Any
from contextlib import contextmanager
from oai.constants import DATABASE_FILE, CONFIG_DIR
class Database:
"""
SQLite database manager for oAI.
Handles all database operations including:
- Configuration key-value storage
- Conversation session persistence
- MCP configuration and statistics
- Database registrations for MCP
Uses context managers for safe connection handling and supports
automatic table creation on first use.
"""
def __init__(self, db_path: Optional[Path] = None):
"""
Initialize the database manager.
Args:
db_path: Optional custom database path. Defaults to standard location.
"""
self.db_path = db_path or DATABASE_FILE
self._ensure_directories()
self._ensure_tables()
def _ensure_directories(self) -> None:
"""Ensure the configuration directory exists."""
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
@contextmanager
def _connection(self):
"""
Context manager for database connections.
Yields:
sqlite3.Connection: Active database connection
Example:
with self._connection() as conn:
conn.execute("SELECT * FROM config")
"""
conn = sqlite3.connect(str(self.db_path))
try:
yield conn
conn.commit()
except Exception:
conn.rollback()
raise
finally:
conn.close()
def _ensure_tables(self) -> None:
"""Create all required tables if they don't exist."""
with self._connection() as conn:
# Main configuration table
conn.execute("""
CREATE TABLE IF NOT EXISTS config (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
)
""")
# Conversation sessions table
conn.execute("""
CREATE TABLE IF NOT EXISTS conversation_sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
timestamp TEXT NOT NULL,
data TEXT NOT NULL
)
""")
# MCP configuration table
conn.execute("""
CREATE TABLE IF NOT EXISTS mcp_config (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
)
""")
# MCP statistics table
conn.execute("""
CREATE TABLE IF NOT EXISTS mcp_stats (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp TEXT NOT NULL,
tool_name TEXT NOT NULL,
folder TEXT,
success INTEGER NOT NULL,
error_message TEXT
)
""")
# MCP databases table
conn.execute("""
CREATE TABLE IF NOT EXISTS mcp_databases (
id INTEGER PRIMARY KEY AUTOINCREMENT,
path TEXT NOT NULL UNIQUE,
name TEXT NOT NULL,
size INTEGER,
tables TEXT,
added_timestamp TEXT NOT NULL
)
""")
# =========================================================================
# CONFIGURATION METHODS
# =========================================================================
def get_config(self, key: str) -> Optional[str]:
"""
Retrieve a configuration value by key.
Args:
key: The configuration key to retrieve
Returns:
The configuration value, or None if not found
"""
with self._connection() as conn:
cursor = conn.execute(
"SELECT value FROM config WHERE key = ?",
(key,)
)
result = cursor.fetchone()
return result[0] if result else None
def set_config(self, key: str, value: str) -> None:
"""
Set a configuration value.
Args:
key: The configuration key
value: The value to store
"""
with self._connection() as conn:
conn.execute(
"INSERT OR REPLACE INTO config (key, value) VALUES (?, ?)",
(key, value)
)
def delete_config(self, key: str) -> bool:
"""
Delete a configuration value.
Args:
key: The configuration key to delete
Returns:
True if a row was deleted, False otherwise
"""
with self._connection() as conn:
cursor = conn.execute(
"DELETE FROM config WHERE key = ?",
(key,)
)
return cursor.rowcount > 0
def get_all_config(self) -> Dict[str, str]:
"""
Retrieve all configuration values.
Returns:
Dictionary of all key-value pairs
"""
with self._connection() as conn:
cursor = conn.execute("SELECT key, value FROM config")
return dict(cursor.fetchall())
# =========================================================================
# MCP CONFIGURATION METHODS
# =========================================================================
def get_mcp_config(self, key: str) -> Optional[str]:
"""
Retrieve an MCP configuration value.
Args:
key: The MCP configuration key
Returns:
The configuration value, or None if not found
"""
with self._connection() as conn:
cursor = conn.execute(
"SELECT value FROM mcp_config WHERE key = ?",
(key,)
)
result = cursor.fetchone()
return result[0] if result else None
def set_mcp_config(self, key: str, value: str) -> None:
"""
Set an MCP configuration value.
Args:
key: The MCP configuration key
value: The value to store
"""
with self._connection() as conn:
conn.execute(
"INSERT OR REPLACE INTO mcp_config (key, value) VALUES (?, ?)",
(key, value)
)
# =========================================================================
# MCP STATISTICS METHODS
# =========================================================================
def log_mcp_stat(
self,
tool_name: str,
folder: Optional[str],
success: bool,
error_message: Optional[str] = None
) -> None:
"""
Log an MCP tool usage event.
Args:
tool_name: Name of the MCP tool that was called
folder: The folder path involved (if any)
success: Whether the call succeeded
error_message: Error message if the call failed
"""
timestamp = datetime.datetime.now().isoformat()
with self._connection() as conn:
conn.execute(
"""INSERT INTO mcp_stats
(timestamp, tool_name, folder, success, error_message)
VALUES (?, ?, ?, ?, ?)""",
(timestamp, tool_name, folder, 1 if success else 0, error_message)
)
def get_mcp_stats(self) -> Dict[str, Any]:
"""
Get aggregated MCP usage statistics.
Returns:
Dictionary containing usage statistics:
- total_calls: Total number of tool calls
- reads: Number of file reads
- lists: Number of directory listings
- searches: Number of file searches
- db_inspects: Number of database inspections
- db_searches: Number of database searches
- db_queries: Number of database queries
- last_used: Timestamp of last usage
"""
with self._connection() as conn:
cursor = conn.execute("""
SELECT
COUNT(*) as total_calls,
SUM(CASE WHEN tool_name = 'read_file' THEN 1 ELSE 0 END) as reads,
SUM(CASE WHEN tool_name = 'list_directory' THEN 1 ELSE 0 END) as lists,
SUM(CASE WHEN tool_name = 'search_files' THEN 1 ELSE 0 END) as searches,
SUM(CASE WHEN tool_name = 'inspect_database' THEN 1 ELSE 0 END) as db_inspects,
SUM(CASE WHEN tool_name = 'search_database' THEN 1 ELSE 0 END) as db_searches,
SUM(CASE WHEN tool_name = 'query_database' THEN 1 ELSE 0 END) as db_queries,
MAX(timestamp) as last_used
FROM mcp_stats
""")
row = cursor.fetchone()
return {
"total_calls": row[0] or 0,
"reads": row[1] or 0,
"lists": row[2] or 0,
"searches": row[3] or 0,
"db_inspects": row[4] or 0,
"db_searches": row[5] or 0,
"db_queries": row[6] or 0,
"last_used": row[7],
}
# =========================================================================
# MCP DATABASE REGISTRY METHODS
# =========================================================================
def add_mcp_database(self, db_info: Dict[str, Any]) -> int:
"""
Register a database for MCP access.
Args:
db_info: Dictionary containing:
- path: Database file path
- name: Display name
- size: File size in bytes
- tables: List of table names
- added: Timestamp when added
Returns:
The database ID
"""
with self._connection() as conn:
conn.execute(
"""INSERT INTO mcp_databases
(path, name, size, tables, added_timestamp)
VALUES (?, ?, ?, ?, ?)""",
(
db_info["path"],
db_info["name"],
db_info["size"],
json.dumps(db_info["tables"]),
db_info["added"]
)
)
cursor = conn.execute(
"SELECT id FROM mcp_databases WHERE path = ?",
(db_info["path"],)
)
return cursor.fetchone()[0]
def remove_mcp_database(self, db_path: str) -> bool:
"""
Remove a database from the MCP registry.
Args:
db_path: Path to the database file
Returns:
True if a row was deleted, False otherwise
"""
with self._connection() as conn:
cursor = conn.execute(
"DELETE FROM mcp_databases WHERE path = ?",
(db_path,)
)
return cursor.rowcount > 0
def get_mcp_databases(self) -> List[Dict[str, Any]]:
"""
Retrieve all registered MCP databases.
Returns:
List of database information dictionaries
"""
with self._connection() as conn:
cursor = conn.execute(
"""SELECT id, path, name, size, tables, added_timestamp
FROM mcp_databases ORDER BY id"""
)
databases = []
for row in cursor.fetchall():
tables_list = json.loads(row[4]) if row[4] else []
databases.append({
"id": row[0],
"path": row[1],
"name": row[2],
"size": row[3],
"tables": tables_list,
"added": row[5],
})
return databases
# =========================================================================
# CONVERSATION METHODS
# =========================================================================
def save_conversation(self, name: str, data: List[Dict[str, str]]) -> None:
"""
Save a conversation session.
Args:
name: Name/identifier for the conversation
data: List of message dictionaries with 'prompt' and 'response' keys
"""
timestamp = datetime.datetime.now().isoformat()
data_json = json.dumps(data)
with self._connection() as conn:
conn.execute(
"""INSERT INTO conversation_sessions
(name, timestamp, data) VALUES (?, ?, ?)""",
(name, timestamp, data_json)
)
def load_conversation(self, name: str) -> Optional[List[Dict[str, str]]]:
"""
Load a conversation by name.
Args:
name: Name of the conversation to load
Returns:
List of messages, or None if not found
"""
with self._connection() as conn:
cursor = conn.execute(
"""SELECT data FROM conversation_sessions
WHERE name = ?
ORDER BY timestamp DESC LIMIT 1""",
(name,)
)
result = cursor.fetchone()
if result:
return json.loads(result[0])
return None
def delete_conversation(self, name: str) -> int:
"""
Delete a conversation by name.
Args:
name: Name of the conversation to delete
Returns:
Number of rows deleted
"""
with self._connection() as conn:
cursor = conn.execute(
"DELETE FROM conversation_sessions WHERE name = ?",
(name,)
)
return cursor.rowcount
def list_conversations(self) -> List[Dict[str, Any]]:
"""
List all saved conversations.
Returns:
List of conversation summaries with name, timestamp, and message_count
"""
with self._connection() as conn:
cursor = conn.execute("""
SELECT name, MAX(timestamp) as last_saved, data
FROM conversation_sessions
GROUP BY name
ORDER BY last_saved DESC
""")
conversations = []
for row in cursor.fetchall():
name, timestamp, data_json = row
data = json.loads(data_json)
conversations.append({
"name": name,
"timestamp": timestamp,
"message_count": len(data),
})
return conversations
# Global database instance for convenience
_db: Optional[Database] = None
def get_database() -> Database:
"""
Get the global database instance.
Returns:
The shared Database instance
"""
global _db
if _db is None:
_db = Database()
return _db

361
oai/config/settings.py Normal file
View File

@@ -0,0 +1,361 @@
"""
Settings management for oAI.
This module provides a centralized settings class that handles all application
configuration with type safety, validation, and persistence.
"""
from dataclasses import dataclass, field
from typing import Optional
from pathlib import Path
from oai.constants import (
DEFAULT_BASE_URL,
DEFAULT_STREAM_ENABLED,
DEFAULT_MAX_TOKENS,
DEFAULT_ONLINE_MODE,
DEFAULT_COST_WARNING_THRESHOLD,
DEFAULT_LOG_MAX_SIZE_MB,
DEFAULT_LOG_BACKUP_COUNT,
DEFAULT_LOG_LEVEL,
DEFAULT_SYSTEM_PROMPT,
VALID_LOG_LEVELS,
)
from oai.config.database import get_database
@dataclass
class Settings:
"""
Application settings with persistence support.
This class provides a clean interface for managing all configuration
options. Settings are automatically loaded from the database on
initialization and can be persisted back.
Attributes:
api_key: OpenRouter API key
base_url: API base URL
default_model: Default model ID to use
default_system_prompt: Custom system prompt (None = use hardcoded default, "" = blank)
stream_enabled: Whether to stream responses
max_tokens: Maximum tokens per request
cost_warning_threshold: Alert threshold for message cost
default_online_mode: Whether online mode is enabled by default
log_max_size_mb: Maximum log file size in MB
log_backup_count: Number of log file backups to keep
log_level: Logging level (debug/info/warning/error/critical)
"""
api_key: Optional[str] = None
base_url: str = DEFAULT_BASE_URL
default_model: Optional[str] = None
default_system_prompt: Optional[str] = None
stream_enabled: bool = DEFAULT_STREAM_ENABLED
max_tokens: int = DEFAULT_MAX_TOKENS
cost_warning_threshold: float = DEFAULT_COST_WARNING_THRESHOLD
default_online_mode: bool = DEFAULT_ONLINE_MODE
log_max_size_mb: int = DEFAULT_LOG_MAX_SIZE_MB
log_backup_count: int = DEFAULT_LOG_BACKUP_COUNT
log_level: str = DEFAULT_LOG_LEVEL
@property
def effective_system_prompt(self) -> str:
"""
Get the effective system prompt to use.
Returns:
The custom prompt if set, hardcoded default if None, or blank if explicitly set to ""
"""
if self.default_system_prompt is None:
return DEFAULT_SYSTEM_PROMPT
return self.default_system_prompt
def __post_init__(self):
"""Validate settings after initialization."""
self._validate()
def _validate(self) -> None:
"""Validate all settings values."""
# Validate log level
if self.log_level.lower() not in VALID_LOG_LEVELS:
raise ValueError(
f"Invalid log level: {self.log_level}. "
f"Must be one of: {', '.join(VALID_LOG_LEVELS.keys())}"
)
# Validate numeric bounds
if self.max_tokens < 1:
raise ValueError("max_tokens must be at least 1")
if self.cost_warning_threshold < 0:
raise ValueError("cost_warning_threshold must be non-negative")
if self.log_max_size_mb < 1:
raise ValueError("log_max_size_mb must be at least 1")
if self.log_backup_count < 0:
raise ValueError("log_backup_count must be non-negative")
@classmethod
def load(cls) -> "Settings":
"""
Load settings from the database.
Returns:
Settings instance with values from database
"""
db = get_database()
# Helper to safely parse boolean
def parse_bool(value: Optional[str], default: bool) -> bool:
if value is None:
return default
return value.lower() in ("on", "true", "1", "yes")
# Helper to safely parse int
def parse_int(value: Optional[str], default: int) -> int:
if value is None:
return default
try:
return int(value)
except ValueError:
return default
# Helper to safely parse float
def parse_float(value: Optional[str], default: float) -> float:
if value is None:
return default
try:
return float(value)
except ValueError:
return default
# Get system prompt from DB: None means not set (use default), "" means explicitly blank
system_prompt_value = db.get_config("default_system_prompt")
return cls(
api_key=db.get_config("api_key"),
base_url=db.get_config("base_url") or DEFAULT_BASE_URL,
default_model=db.get_config("default_model"),
default_system_prompt=system_prompt_value,
stream_enabled=parse_bool(
db.get_config("stream_enabled"),
DEFAULT_STREAM_ENABLED
),
max_tokens=parse_int(
db.get_config("max_token"),
DEFAULT_MAX_TOKENS
),
cost_warning_threshold=parse_float(
db.get_config("cost_warning_threshold"),
DEFAULT_COST_WARNING_THRESHOLD
),
default_online_mode=parse_bool(
db.get_config("default_online_mode"),
DEFAULT_ONLINE_MODE
),
log_max_size_mb=parse_int(
db.get_config("log_max_size_mb"),
DEFAULT_LOG_MAX_SIZE_MB
),
log_backup_count=parse_int(
db.get_config("log_backup_count"),
DEFAULT_LOG_BACKUP_COUNT
),
log_level=db.get_config("log_level") or DEFAULT_LOG_LEVEL,
)
def save(self) -> None:
"""Persist all settings to the database."""
db = get_database()
# Only save API key if it exists
if self.api_key:
db.set_config("api_key", self.api_key)
db.set_config("base_url", self.base_url)
if self.default_model:
db.set_config("default_model", self.default_model)
# Save system prompt: None means not set (don't save), otherwise save the value (even if "")
if self.default_system_prompt is not None:
db.set_config("default_system_prompt", self.default_system_prompt)
db.set_config("stream_enabled", "on" if self.stream_enabled else "off")
db.set_config("max_token", str(self.max_tokens))
db.set_config("cost_warning_threshold", str(self.cost_warning_threshold))
db.set_config("default_online_mode", "on" if self.default_online_mode else "off")
db.set_config("log_max_size_mb", str(self.log_max_size_mb))
db.set_config("log_backup_count", str(self.log_backup_count))
db.set_config("log_level", self.log_level)
def set_api_key(self, api_key: str) -> None:
"""
Set and persist the API key.
Args:
api_key: The new API key
"""
self.api_key = api_key.strip()
get_database().set_config("api_key", self.api_key)
def set_base_url(self, url: str) -> None:
"""
Set and persist the base URL.
Args:
url: The new base URL
"""
self.base_url = url.strip()
get_database().set_config("base_url", self.base_url)
def set_default_model(self, model_id: str) -> None:
"""
Set and persist the default model.
Args:
model_id: The model ID to set as default
"""
self.default_model = model_id
get_database().set_config("default_model", model_id)
def set_default_system_prompt(self, prompt: str) -> None:
"""
Set and persist the default system prompt.
Args:
prompt: The system prompt to use for all new sessions.
Empty string "" means blank prompt (no system message).
"""
self.default_system_prompt = prompt
get_database().set_config("default_system_prompt", prompt)
def clear_default_system_prompt(self) -> None:
"""
Clear the custom system prompt and revert to hardcoded default.
This removes the custom prompt from the database, causing the
application to use the built-in DEFAULT_SYSTEM_PROMPT.
"""
self.default_system_prompt = None
# Remove from database to indicate "not set"
db = get_database()
with db._connection() as conn:
conn.execute("DELETE FROM config WHERE key = ?", ("default_system_prompt",))
conn.commit()
def set_stream_enabled(self, enabled: bool) -> None:
"""
Set and persist the streaming preference.
Args:
enabled: Whether to enable streaming
"""
self.stream_enabled = enabled
get_database().set_config("stream_enabled", "on" if enabled else "off")
def set_max_tokens(self, max_tokens: int) -> None:
"""
Set and persist the maximum tokens.
Args:
max_tokens: Maximum number of tokens
Raises:
ValueError: If max_tokens is less than 1
"""
if max_tokens < 1:
raise ValueError("max_tokens must be at least 1")
self.max_tokens = max_tokens
get_database().set_config("max_token", str(max_tokens))
def set_cost_warning_threshold(self, threshold: float) -> None:
"""
Set and persist the cost warning threshold.
Args:
threshold: Cost threshold in USD
Raises:
ValueError: If threshold is negative
"""
if threshold < 0:
raise ValueError("cost_warning_threshold must be non-negative")
self.cost_warning_threshold = threshold
get_database().set_config("cost_warning_threshold", str(threshold))
def set_default_online_mode(self, enabled: bool) -> None:
"""
Set and persist the default online mode.
Args:
enabled: Whether online mode should be enabled by default
"""
self.default_online_mode = enabled
get_database().set_config("default_online_mode", "on" if enabled else "off")
def set_log_level(self, level: str) -> None:
"""
Set and persist the log level.
Args:
level: The log level (debug/info/warning/error/critical)
Raises:
ValueError: If level is not valid
"""
level_lower = level.lower()
if level_lower not in VALID_LOG_LEVELS:
raise ValueError(
f"Invalid log level: {level}. "
f"Must be one of: {', '.join(VALID_LOG_LEVELS.keys())}"
)
self.log_level = level_lower
get_database().set_config("log_level", level_lower)
def set_log_max_size(self, size_mb: int) -> None:
"""
Set and persist the maximum log file size.
Args:
size_mb: Maximum size in megabytes
Raises:
ValueError: If size_mb is less than 1
"""
if size_mb < 1:
raise ValueError("log_max_size_mb must be at least 1")
# Cap at 100 MB for safety
self.log_max_size_mb = min(size_mb, 100)
get_database().set_config("log_max_size_mb", str(self.log_max_size_mb))
# Global settings instance
_settings: Optional[Settings] = None
def get_settings() -> Settings:
"""
Get the global settings instance.
Returns:
The shared Settings instance, loading from database if needed
"""
global _settings
if _settings is None:
_settings = Settings.load()
return _settings
def reload_settings() -> Settings:
"""
Force reload settings from the database.
Returns:
Fresh Settings instance
"""
global _settings
_settings = Settings.load()
return _settings

448
oai/constants.py Normal file
View File

@@ -0,0 +1,448 @@
"""
Application-wide constants for oAI.
This module contains all configuration constants, default values, and static
definitions used throughout the application. Centralizing these values makes
the codebase easier to maintain and configure.
"""
from pathlib import Path
from typing import Set, Dict, Any
import logging
# =============================================================================
# APPLICATION METADATA
# =============================================================================
APP_NAME = "oAI"
APP_VERSION = "2.1.0"
APP_URL = "https://iurl.no/oai"
APP_DESCRIPTION = "OpenRouter AI Chat Client with MCP Integration"
# =============================================================================
# FILE PATHS
# =============================================================================
HOME_DIR = Path.home()
CONFIG_DIR = HOME_DIR / ".config" / "oai"
CACHE_DIR = HOME_DIR / ".cache" / "oai"
HISTORY_FILE = CONFIG_DIR / "history.txt"
DATABASE_FILE = CONFIG_DIR / "oai_config.db"
LOG_FILE = CONFIG_DIR / "oai.log"
# =============================================================================
# API CONFIGURATION
# =============================================================================
DEFAULT_BASE_URL = "https://openrouter.ai/api/v1"
DEFAULT_STREAM_ENABLED = True
DEFAULT_MAX_TOKENS = 100_000
DEFAULT_ONLINE_MODE = False
# =============================================================================
# DEFAULT SYSTEM PROMPT
# =============================================================================
DEFAULT_SYSTEM_PROMPT = (
"You are a knowledgeable and helpful AI assistant. Provide clear, accurate, "
"and well-structured responses. Be concise yet thorough. When uncertain about "
"something, acknowledge your limitations. For technical topics, include relevant "
"details and examples when helpful."
)
# =============================================================================
# PRICING DEFAULTS (per million tokens)
# =============================================================================
DEFAULT_INPUT_PRICE = 3.0
DEFAULT_OUTPUT_PRICE = 15.0
MODEL_PRICING: Dict[str, float] = {
"input": DEFAULT_INPUT_PRICE,
"output": DEFAULT_OUTPUT_PRICE,
}
# =============================================================================
# CREDIT ALERTS
# =============================================================================
LOW_CREDIT_RATIO = 0.1 # Alert when credits < 10% of total
LOW_CREDIT_AMOUNT = 1.0 # Alert when credits < $1.00
DEFAULT_COST_WARNING_THRESHOLD = 0.01 # Alert when single message cost exceeds this
COST_WARNING_THRESHOLD = DEFAULT_COST_WARNING_THRESHOLD # Alias for convenience
# =============================================================================
# LOGGING CONFIGURATION
# =============================================================================
DEFAULT_LOG_MAX_SIZE_MB = 10
DEFAULT_LOG_BACKUP_COUNT = 2
DEFAULT_LOG_LEVEL = "info"
VALID_LOG_LEVELS: Dict[str, int] = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"error": logging.ERROR,
"critical": logging.CRITICAL,
}
# =============================================================================
# FILE HANDLING
# =============================================================================
# Maximum file size for reading (10 MB)
MAX_FILE_SIZE = 10 * 1024 * 1024
# Content truncation threshold (50 KB)
CONTENT_TRUNCATION_THRESHOLD = 50 * 1024
# Maximum items in directory listing
MAX_LIST_ITEMS = 1000
# Supported code file extensions for syntax highlighting
SUPPORTED_CODE_EXTENSIONS: Set[str] = {
".py", ".js", ".ts", ".cs", ".java", ".c", ".cpp", ".h", ".hpp",
".rb", ".ruby", ".php", ".swift", ".kt", ".kts", ".go",
".sh", ".bat", ".ps1", ".r", ".scala", ".pl", ".lua", ".dart",
".elm", ".xml", ".json", ".yaml", ".yml", ".md", ".txt",
}
# All allowed file extensions for attachment
ALLOWED_FILE_EXTENSIONS: Set[str] = {
# Code files
".py", ".js", ".ts", ".jsx", ".tsx", ".vue", ".java", ".c", ".cpp", ".cc", ".cxx",
".h", ".hpp", ".hxx", ".rb", ".go", ".rs", ".swift", ".kt", ".kts", ".php",
".sh", ".bash", ".zsh", ".fish", ".bat", ".cmd", ".ps1",
# Data files
".json", ".csv", ".yaml", ".yml", ".toml", ".xml", ".sql", ".db", ".sqlite", ".sqlite3",
# Documents
".txt", ".md", ".log", ".conf", ".cfg", ".ini", ".env", ".properties",
# Images
".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".svg", ".ico",
# Archives
".zip", ".tar", ".gz", ".bz2", ".7z", ".rar", ".xz",
# Config files
".lock", ".gitignore", ".dockerignore", ".editorconfig", ".eslintrc",
".prettierrc", ".babelrc", ".nvmrc", ".npmrc",
# Binary/Compiled
".pyc", ".pyo", ".pyd", ".so", ".dll", ".dylib", ".exe", ".app",
".dmg", ".pkg", ".deb", ".rpm", ".apk", ".ipa",
# ML/AI
".pkl", ".pickle", ".joblib", ".npy", ".npz", ".safetensors", ".onnx",
".pt", ".pth", ".ckpt", ".pb", ".tflite", ".mlmodel", ".coreml", ".rknn",
# Data formats
".wasm", ".proto", ".graphql", ".graphqls", ".grpc", ".avro", ".parquet",
".orc", ".feather", ".arrow", ".hdf5", ".h5", ".mat", ".rdata", ".rds",
# Other
".pdf", ".class", ".jar", ".war",
}
# =============================================================================
# SECURITY CONFIGURATION
# =============================================================================
# System directories that should never be accessed
SYSTEM_DIRS_BLACKLIST: Set[str] = {
# macOS
"/System", "/Library", "/private", "/usr", "/bin", "/sbin",
# Linux
"/boot", "/dev", "/proc", "/sys", "/root",
# Windows
"C:\\Windows", "C:\\Program Files", "C:\\Program Files (x86)",
}
# Directories to skip during file operations
SKIP_DIRECTORIES: Set[str] = {
# Python virtual environments
".venv", "venv", "env", "virtualenv",
"site-packages", "dist-packages",
# Python caches
"__pycache__", ".pytest_cache", ".mypy_cache",
# JavaScript/Node
"node_modules",
# Version control
".git", ".svn",
# IDEs
".idea", ".vscode",
# Build directories
"build", "dist", "eggs", ".eggs",
}
# =============================================================================
# DATABASE QUERIES - SQL SAFETY
# =============================================================================
# Maximum query execution timeout (seconds)
MAX_QUERY_TIMEOUT = 5
# Maximum rows returned from queries
MAX_QUERY_RESULTS = 1000
# Default rows per query
DEFAULT_QUERY_LIMIT = 100
# Keywords that are blocked in database queries
DANGEROUS_SQL_KEYWORDS: Set[str] = {
"INSERT", "UPDATE", "DELETE", "DROP", "CREATE",
"ALTER", "TRUNCATE", "REPLACE", "ATTACH", "DETACH",
"PRAGMA", "VACUUM", "REINDEX",
}
# =============================================================================
# MCP CONFIGURATION
# =============================================================================
# Maximum tool call iterations per request
MAX_TOOL_LOOPS = 5
# =============================================================================
# VALID COMMANDS
# =============================================================================
VALID_COMMANDS: Set[str] = {
"/retry", "/online", "/memory", "/paste", "/export", "/save", "/load",
"/delete", "/list", "/prev", "/next", "/stats", "/middleout", "/reset",
"/info", "/model", "/maxtoken", "/system", "/config", "/credits", "/clear",
"/cl", "/help", "/mcp",
}
# =============================================================================
# COMMAND HELP DATABASE
# =============================================================================
COMMAND_HELP: Dict[str, Dict[str, Any]] = {
"/clear": {
"aliases": ["/cl"],
"description": "Clear the terminal screen for a clean interface.",
"usage": "/clear\n/cl",
"examples": [
("Clear screen", "/clear"),
("Using short alias", "/cl"),
],
"notes": "You can also use the keyboard shortcut Ctrl+L.",
},
"/help": {
"description": "Display help information for commands.",
"usage": "/help [command|topic]",
"examples": [
("Show all commands", "/help"),
("Get help for a specific command", "/help /model"),
("Get detailed MCP help", "/help mcp"),
],
"notes": "Use /help without arguments to see the full command list.",
},
"mcp": {
"description": "Complete guide to MCP (Model Context Protocol).",
"usage": "See detailed examples below",
"examples": [],
"notes": """
MCP (Model Context Protocol) gives your AI assistant direct access to:
• Local files and folders (read, search, list)
• SQLite databases (inspect, search, query)
FILE MODE (default):
/mcp on Start MCP server
/mcp add ~/Documents Grant access to folder
/mcp list View all allowed folders
DATABASE MODE:
/mcp add db ~/app/data.db Add specific database
/mcp db list View all databases
/mcp db 1 Work with database #1
/mcp files Switch back to file mode
WRITE MODE (optional):
/mcp write on Enable file modifications
/mcp write off Disable write mode (back to read-only)
For command-specific help: /help /mcp
""",
},
"/mcp": {
"description": "Manage MCP for local file access and SQLite database querying.",
"usage": "/mcp <command> [args]",
"examples": [
("Enable MCP server", "/mcp on"),
("Disable MCP server", "/mcp off"),
("Show MCP status", "/mcp status"),
("", ""),
("━━━ FILE MODE ━━━", ""),
("Add folder for file access", "/mcp add ~/Documents"),
("Remove folder", "/mcp remove ~/Desktop"),
("List allowed folders", "/mcp list"),
("Enable write mode", "/mcp write on"),
("", ""),
("━━━ DATABASE MODE ━━━", ""),
("Add SQLite database", "/mcp add db ~/app/data.db"),
("List all databases", "/mcp db list"),
("Switch to database #1", "/mcp db 1"),
("Switch back to file mode", "/mcp files"),
],
"notes": "MCP allows AI to read local files and query SQLite databases.",
},
"/memory": {
"description": "Toggle conversation memory.",
"usage": "/memory [on|off]",
"examples": [
("Check current memory status", "/memory"),
("Enable conversation memory", "/memory on"),
("Disable memory (save costs)", "/memory off"),
],
"notes": "Memory is ON by default. Disabling saves tokens.",
},
"/online": {
"description": "Enable or disable online mode (web search).",
"usage": "/online [on|off]",
"examples": [
("Check online mode status", "/online"),
("Enable web search", "/online on"),
("Disable web search", "/online off"),
],
"notes": "Not all models support online mode.",
},
"/paste": {
"description": "Paste plain text from clipboard and send to the AI.",
"usage": "/paste [prompt]",
"examples": [
("Paste clipboard content", "/paste"),
("Paste with a question", "/paste Explain this code"),
],
"notes": "Only plain text is supported.",
},
"/retry": {
"description": "Resend the last prompt from conversation history.",
"usage": "/retry",
"examples": [("Retry last message", "/retry")],
"notes": "Requires at least one message in history.",
},
"/next": {
"description": "View the next response in conversation history.",
"usage": "/next",
"examples": [("Navigate to next response", "/next")],
"notes": "Use /prev to go backward.",
},
"/prev": {
"description": "View the previous response in conversation history.",
"usage": "/prev",
"examples": [("Navigate to previous response", "/prev")],
"notes": "Use /next to go forward.",
},
"/reset": {
"description": "Clear conversation history and reset system prompt.",
"usage": "/reset",
"examples": [("Reset conversation", "/reset")],
"notes": "Requires confirmation.",
},
"/info": {
"description": "Display detailed information about a model.",
"usage": "/info [model_id]",
"examples": [
("Show current model info", "/info"),
("Show specific model info", "/info gpt-4o"),
],
"notes": "Shows pricing, capabilities, and context length.",
},
"/model": {
"description": "Select or change the AI model.",
"usage": "/model [search_term]",
"examples": [
("List all models", "/model"),
("Search for GPT models", "/model gpt"),
("Search for Claude models", "/model claude"),
],
"notes": "Models are numbered for easy selection.",
},
"/config": {
"description": "View or modify application configuration.",
"usage": "/config [setting] [value]",
"examples": [
("View all settings", "/config"),
("Set API key", "/config api"),
("Set default model", "/config model"),
("Set system prompt", "/config system You are a helpful assistant"),
("Enable streaming", "/config stream on"),
],
"notes": "Available: api, url, model, system, stream, costwarning, maxtoken, online, loglevel.",
},
"/maxtoken": {
"description": "Set a temporary session token limit.",
"usage": "/maxtoken [value]",
"examples": [
("View current session limit", "/maxtoken"),
("Set session limit to 2000", "/maxtoken 2000"),
],
"notes": "Cannot exceed stored max token limit.",
},
"/system": {
"description": "Set or clear the session-level system prompt.",
"usage": "/system [prompt|clear|default <prompt>]",
"examples": [
("View current system prompt", "/system"),
("Set as Python expert", "/system You are a Python expert"),
("Multiline with newlines", r"/system You are an expert.\nBe clear and concise."),
("Save as default", "/system default You are a helpful assistant"),
("Revert to default", "/system clear"),
("Blank prompt", '/system ""'),
],
"notes": r"Use \n for newlines. /system clear reverts to hardcoded default.",
},
"/save": {
"description": "Save the current conversation history.",
"usage": "/save <name>",
"examples": [("Save conversation", "/save my_chat")],
"notes": "Saved conversations can be loaded later with /load.",
},
"/load": {
"description": "Load a saved conversation.",
"usage": "/load <name|number>",
"examples": [
("Load by name", "/load my_chat"),
("Load by number from /list", "/load 3"),
],
"notes": "Use /list to see numbered conversations.",
},
"/delete": {
"description": "Delete a saved conversation.",
"usage": "/delete <name|number>",
"examples": [("Delete by name", "/delete my_chat")],
"notes": "Requires confirmation. Cannot be undone.",
},
"/list": {
"description": "List all saved conversations.",
"usage": "/list",
"examples": [("Show saved conversations", "/list")],
"notes": "Conversations are numbered for use with /load and /delete.",
},
"/export": {
"description": "Export the current conversation to a file.",
"usage": "/export <format> <filename>",
"examples": [
("Export as Markdown", "/export md notes.md"),
("Export as JSON", "/export json conversation.json"),
("Export as HTML", "/export html report.html"),
],
"notes": "Available formats: md, json, html.",
},
"/stats": {
"description": "Display session statistics.",
"usage": "/stats",
"examples": [("View session statistics", "/stats")],
"notes": "Shows tokens, costs, and credits.",
},
"/credits": {
"description": "Display your OpenRouter account credits.",
"usage": "/credits",
"examples": [("Check credits", "/credits")],
"notes": "Shows total, used, and remaining credits.",
},
"/middleout": {
"description": "Toggle middle-out transform for long prompts.",
"usage": "/middleout [on|off]",
"examples": [
("Check status", "/middleout"),
("Enable compression", "/middleout on"),
],
"notes": "Compresses prompts exceeding context size.",
},
}

14
oai/core/__init__.py Normal file
View File

@@ -0,0 +1,14 @@
"""
Core functionality for oAI.
This module provides the main session management and AI client
classes that power the chat application.
"""
from oai.core.session import ChatSession
from oai.core.client import AIClient
__all__ = [
"ChatSession",
"AIClient",
]

422
oai/core/client.py Normal file
View File

@@ -0,0 +1,422 @@
"""
AI Client for oAI.
This module provides a high-level client for interacting with AI models
through the provider abstraction layer.
"""
import asyncio
import json
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
from oai.constants import APP_NAME, APP_URL, MODEL_PRICING
from oai.providers.base import (
AIProvider,
ChatMessage,
ChatResponse,
ModelInfo,
StreamChunk,
ToolCall,
UsageStats,
)
from oai.providers.openrouter import OpenRouterProvider
from oai.utils.logging import get_logger
class AIClient:
"""
High-level AI client for chat interactions.
Provides a simplified interface for sending chat requests,
handling streaming, and managing tool calls.
Attributes:
provider: The underlying AI provider
default_model: Default model ID to use
http_headers: Custom HTTP headers for requests
"""
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
provider_class: type = OpenRouterProvider,
app_name: str = APP_NAME,
app_url: str = APP_URL,
):
"""
Initialize the AI client.
Args:
api_key: API key for authentication
base_url: Optional custom base URL
provider_class: Provider class to use (default: OpenRouterProvider)
app_name: Application name for headers
app_url: Application URL for headers
"""
self.provider: AIProvider = provider_class(
api_key=api_key,
base_url=base_url,
app_name=app_name,
app_url=app_url,
)
self.default_model: Optional[str] = None
self.logger = get_logger()
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
"""
Get available models.
Args:
filter_text_only: Whether to exclude video-only models
Returns:
List of ModelInfo objects
"""
return self.provider.list_models(filter_text_only=filter_text_only)
def get_model(self, model_id: str) -> Optional[ModelInfo]:
"""
Get information about a specific model.
Args:
model_id: Model identifier
Returns:
ModelInfo or None if not found
"""
return self.provider.get_model(model_id)
def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
"""
Get raw model data for provider-specific fields.
Args:
model_id: Model identifier
Returns:
Raw model dictionary or None
"""
if hasattr(self.provider, "get_raw_model"):
return self.provider.get_raw_model(model_id)
return None
def chat(
self,
messages: List[Dict[str, Any]],
model: Optional[str] = None,
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
system_prompt: Optional[str] = None,
online: bool = False,
transforms: Optional[List[str]] = None,
) -> Union[ChatResponse, Iterator[StreamChunk]]:
"""
Send a chat request.
Args:
messages: List of message dictionaries
model: Model ID (uses default if not specified)
stream: Whether to stream the response
max_tokens: Maximum tokens in response
temperature: Sampling temperature
tools: Tool definitions for function calling
tool_choice: Tool selection mode
system_prompt: System prompt to prepend
online: Whether to enable online mode
transforms: List of transforms (e.g., ["middle-out"])
Returns:
ChatResponse for non-streaming, Iterator[StreamChunk] for streaming
Raises:
ValueError: If no model specified and no default set
"""
model_id = model or self.default_model
if not model_id:
raise ValueError("No model specified and no default set")
# Apply online mode suffix
if online and hasattr(self.provider, "get_effective_model_id"):
model_id = self.provider.get_effective_model_id(model_id, True)
# Convert dict messages to ChatMessage objects
chat_messages = []
# Add system prompt if provided
if system_prompt:
chat_messages.append(ChatMessage(role="system", content=system_prompt))
# Convert message dicts
for msg in messages:
# Convert tool_calls dicts to ToolCall objects if present
tool_calls_data = msg.get("tool_calls")
tool_calls_obj = None
if tool_calls_data:
from oai.providers.base import ToolCall, ToolFunction
tool_calls_obj = []
for tc in tool_calls_data:
# Handle both ToolCall objects and dicts
if isinstance(tc, ToolCall):
tool_calls_obj.append(tc)
elif isinstance(tc, dict):
func_data = tc.get("function", {})
tool_calls_obj.append(
ToolCall(
id=tc.get("id", ""),
type=tc.get("type", "function"),
function=ToolFunction(
name=func_data.get("name", ""),
arguments=func_data.get("arguments", "{}"),
),
)
)
chat_messages.append(
ChatMessage(
role=msg.get("role", "user"),
content=msg.get("content"),
tool_calls=tool_calls_obj,
tool_call_id=msg.get("tool_call_id"),
)
)
self.logger.debug(
f"Sending chat request: model={model_id}, "
f"messages={len(chat_messages)}, stream={stream}"
)
return self.provider.chat(
model=model_id,
messages=chat_messages,
stream=stream,
max_tokens=max_tokens,
temperature=temperature,
tools=tools,
tool_choice=tool_choice,
transforms=transforms,
)
def chat_with_tools(
self,
messages: List[Dict[str, Any]],
tools: List[Dict[str, Any]],
tool_executor: Callable[[str, Dict[str, Any]], Dict[str, Any]],
model: Optional[str] = None,
max_loops: int = 5,
max_tokens: Optional[int] = None,
system_prompt: Optional[str] = None,
on_tool_call: Optional[Callable[[ToolCall], None]] = None,
on_tool_result: Optional[Callable[[str, Dict[str, Any]], None]] = None,
) -> ChatResponse:
"""
Send a chat request with automatic tool call handling.
Executes tool calls returned by the model and continues
the conversation until no more tool calls are requested.
Args:
messages: Initial messages
tools: Tool definitions
tool_executor: Function to execute tool calls
model: Model ID
max_loops: Maximum tool call iterations
max_tokens: Maximum response tokens
system_prompt: System prompt
on_tool_call: Callback when tool is called
on_tool_result: Callback when tool returns result
Returns:
Final ChatResponse after all tool calls complete
"""
model_id = model or self.default_model
if not model_id:
raise ValueError("No model specified and no default set")
# Build initial messages
chat_messages = []
if system_prompt:
chat_messages.append({"role": "system", "content": system_prompt})
chat_messages.extend(messages)
loop_count = 0
current_response: Optional[ChatResponse] = None
while loop_count < max_loops:
# Send request
response = self.chat(
messages=chat_messages,
model=model_id,
stream=False,
max_tokens=max_tokens,
tools=tools,
tool_choice="auto",
)
if not isinstance(response, ChatResponse):
raise ValueError("Expected non-streaming response")
current_response = response
# Check for tool calls
tool_calls = response.tool_calls
if not tool_calls:
break
self.logger.info(f"Model requested {len(tool_calls)} tool call(s)")
# Process each tool call
tool_results = []
for tc in tool_calls:
if on_tool_call:
on_tool_call(tc)
try:
args = json.loads(tc.function.arguments)
except json.JSONDecodeError as e:
self.logger.error(f"Failed to parse tool arguments: {e}")
result = {"error": f"Invalid arguments: {e}"}
else:
result = tool_executor(tc.function.name, args)
if on_tool_result:
on_tool_result(tc.function.name, result)
tool_results.append({
"tool_call_id": tc.id,
"role": "tool",
"name": tc.function.name,
"content": json.dumps(result),
})
# Add assistant message with tool calls
assistant_msg = {
"role": "assistant",
"content": response.content,
"tool_calls": [
{
"id": tc.id,
"type": tc.type,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in tool_calls
],
}
chat_messages.append(assistant_msg)
chat_messages.extend(tool_results)
loop_count += 1
if loop_count >= max_loops:
self.logger.warning(f"Reached max tool call loops ({max_loops})")
return current_response
def stream_chat(
self,
messages: List[Dict[str, Any]],
model: Optional[str] = None,
max_tokens: Optional[int] = None,
system_prompt: Optional[str] = None,
online: bool = False,
on_chunk: Optional[Callable[[StreamChunk], None]] = None,
) -> tuple[str, Optional[UsageStats]]:
"""
Stream a chat response and collect the full text.
Args:
messages: Chat messages
model: Model ID
max_tokens: Maximum tokens
system_prompt: System prompt
online: Online mode
on_chunk: Optional callback for each chunk
Returns:
Tuple of (full_response_text, usage_stats)
"""
response = self.chat(
messages=messages,
model=model,
stream=True,
max_tokens=max_tokens,
system_prompt=system_prompt,
online=online,
)
if isinstance(response, ChatResponse):
# Not actually streaming
return response.content or "", response.usage
full_text = ""
usage: Optional[UsageStats] = None
for chunk in response:
if chunk.error:
self.logger.error(f"Stream error: {chunk.error}")
break
if chunk.delta_content:
full_text += chunk.delta_content
if on_chunk:
on_chunk(chunk)
if chunk.usage:
usage = chunk.usage
return full_text, usage
def get_credits(self) -> Optional[Dict[str, Any]]:
"""
Get account credit information.
Returns:
Credit info dict or None if unavailable
"""
return self.provider.get_credits()
def estimate_cost(
self,
model_id: str,
input_tokens: int,
output_tokens: int,
) -> float:
"""
Estimate cost for a completion.
Args:
model_id: Model ID
input_tokens: Number of input tokens
output_tokens: Number of output tokens
Returns:
Estimated cost in USD
"""
if hasattr(self.provider, "estimate_cost"):
return self.provider.estimate_cost(model_id, input_tokens, output_tokens)
# Fallback to default pricing
input_cost = MODEL_PRICING["input"] * input_tokens / 1_000_000
output_cost = MODEL_PRICING["output"] * output_tokens / 1_000_000
return input_cost + output_cost
def set_default_model(self, model_id: str) -> None:
"""
Set the default model.
Args:
model_id: Model ID to use as default
"""
self.default_model = model_id
self.logger.info(f"Default model set to: {model_id}")
def clear_cache(self) -> None:
"""Clear the provider's model cache."""
if hasattr(self.provider, "clear_cache"):
self.provider.clear_cache()

659
oai/core/session.py Normal file
View File

@@ -0,0 +1,659 @@
"""
Chat session management for oAI.
This module provides the ChatSession class that manages an interactive
chat session including history, state, and message handling.
"""
import asyncio
import json
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
from rich.live import Live
from rich.markdown import Markdown
from oai.commands.registry import CommandContext, CommandResult, registry
from oai.config.database import Database
from oai.config.settings import Settings
from oai.constants import (
COST_WARNING_THRESHOLD,
LOW_CREDIT_AMOUNT,
LOW_CREDIT_RATIO,
)
from oai.core.client import AIClient
from oai.mcp.manager import MCPManager
from oai.providers.base import ChatResponse, StreamChunk, UsageStats
from oai.ui.console import (
console,
display_markdown,
display_panel,
print_error,
print_info,
print_success,
print_warning,
)
from oai.ui.prompts import prompt_copy_response
from oai.utils.logging import get_logger
@dataclass
class SessionStats:
"""
Statistics for the current session.
Tracks tokens, costs, and message counts.
"""
total_input_tokens: int = 0
total_output_tokens: int = 0
total_cost: float = 0.0
message_count: int = 0
@property
def total_tokens(self) -> int:
"""Get total token count."""
return self.total_input_tokens + self.total_output_tokens
def add_usage(self, usage: Optional[UsageStats], cost: float = 0.0) -> None:
"""
Add usage stats from a response.
Args:
usage: Usage statistics
cost: Cost if not in usage
"""
if usage:
self.total_input_tokens += usage.prompt_tokens
self.total_output_tokens += usage.completion_tokens
if usage.total_cost_usd:
self.total_cost += usage.total_cost_usd
else:
self.total_cost += cost
else:
self.total_cost += cost
self.message_count += 1
@dataclass
class HistoryEntry:
"""
A single entry in the conversation history.
Stores the user prompt, assistant response, and metrics.
"""
prompt: str
response: str
prompt_tokens: int = 0
completion_tokens: int = 0
msg_cost: float = 0.0
timestamp: Optional[float] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
return {
"prompt": self.prompt,
"response": self.response,
"prompt_tokens": self.prompt_tokens,
"completion_tokens": self.completion_tokens,
"msg_cost": self.msg_cost,
}
class ChatSession:
"""
Manages an interactive chat session.
Handles conversation history, state management, command processing,
and communication with the AI client.
Attributes:
client: AI client for API requests
settings: Application settings
mcp_manager: MCP manager for file/database access
history: Conversation history
stats: Session statistics
"""
def __init__(
self,
client: AIClient,
settings: Settings,
mcp_manager: Optional[MCPManager] = None,
):
"""
Initialize a chat session.
Args:
client: AI client instance
settings: Application settings
mcp_manager: Optional MCP manager
"""
self.client = client
self.settings = settings
self.mcp_manager = mcp_manager
self.db = Database()
self.history: List[HistoryEntry] = []
self.stats = SessionStats()
# Session state
self.system_prompt: str = settings.effective_system_prompt
self.memory_enabled: bool = True
self.memory_start_index: int = 0
self.online_enabled: bool = settings.default_online_mode
self.middle_out_enabled: bool = False
self.session_max_token: int = 0
self.current_index: int = 0
# Selected model
self.selected_model: Optional[Dict[str, Any]] = None
self.logger = get_logger()
def get_context(self) -> CommandContext:
"""
Get the current command context.
Returns:
CommandContext with current session state
"""
return CommandContext(
settings=self.settings,
provider=self.client.provider,
mcp_manager=self.mcp_manager,
selected_model_raw=self.selected_model,
session_history=[e.to_dict() for e in self.history],
session_system_prompt=self.system_prompt,
memory_enabled=self.memory_enabled,
memory_start_index=self.memory_start_index,
online_enabled=self.online_enabled,
middle_out_enabled=self.middle_out_enabled,
session_max_token=self.session_max_token,
total_input_tokens=self.stats.total_input_tokens,
total_output_tokens=self.stats.total_output_tokens,
total_cost=self.stats.total_cost,
message_count=self.stats.message_count,
current_index=self.current_index,
)
def set_model(self, model: Dict[str, Any]) -> None:
"""
Set the selected model.
Args:
model: Raw model dictionary
"""
self.selected_model = model
self.client.set_default_model(model["id"])
self.logger.info(f"Model selected: {model['id']}")
def build_api_messages(self, user_input: str) -> List[Dict[str, Any]]:
"""
Build the messages array for an API request.
Includes system prompt, history (if memory enabled), and current input.
Args:
user_input: Current user input
Returns:
List of message dictionaries
"""
messages = []
# Add system prompt
if self.system_prompt:
messages.append({"role": "system", "content": self.system_prompt})
# Add database context if in database mode
if self.mcp_manager and self.mcp_manager.enabled:
if self.mcp_manager.mode == "database" and self.mcp_manager.selected_db_index is not None:
db = self.mcp_manager.databases[self.mcp_manager.selected_db_index]
db_context = (
f"You are connected to SQLite database: {db['name']}\n"
f"Available tables: {', '.join(db['tables'])}\n\n"
"Use inspect_database, search_database, or query_database tools. "
"All queries are read-only."
)
messages.append({"role": "system", "content": db_context})
# Add history if memory enabled
if self.memory_enabled:
for i in range(self.memory_start_index, len(self.history)):
entry = self.history[i]
messages.append({"role": "user", "content": entry.prompt})
messages.append({"role": "assistant", "content": entry.response})
# Add current message
messages.append({"role": "user", "content": user_input})
return messages
def get_mcp_tools(self) -> Optional[List[Dict[str, Any]]]:
"""
Get MCP tool definitions if available.
Returns:
List of tool schemas or None
"""
if not self.mcp_manager or not self.mcp_manager.enabled:
return None
if not self.selected_model:
return None
# Check if model supports tools
supported_params = self.selected_model.get("supported_parameters", [])
if "tools" not in supported_params and "functions" not in supported_params:
return None
return self.mcp_manager.get_tools_schema()
async def execute_tool(
self,
tool_name: str,
tool_args: Dict[str, Any],
) -> Dict[str, Any]:
"""
Execute an MCP tool.
Args:
tool_name: Name of the tool
tool_args: Tool arguments
Returns:
Tool execution result
"""
if not self.mcp_manager:
return {"error": "MCP not available"}
return await self.mcp_manager.call_tool(tool_name, **tool_args)
def send_message(
self,
user_input: str,
stream: bool = True,
on_stream_chunk: Optional[Callable[[str], None]] = None,
) -> Tuple[str, Optional[UsageStats], float]:
"""
Send a message and get a response.
Args:
user_input: User's input text
stream: Whether to stream the response
on_stream_chunk: Callback for stream chunks
Returns:
Tuple of (response_text, usage_stats, response_time)
"""
if not self.selected_model:
raise ValueError("No model selected")
start_time = time.time()
messages = self.build_api_messages(user_input)
# Get MCP tools
tools = self.get_mcp_tools()
if tools:
# Disable streaming when tools are present
stream = False
# Build request parameters
model_id = self.selected_model["id"]
if self.online_enabled:
if hasattr(self.client.provider, "get_effective_model_id"):
model_id = self.client.provider.get_effective_model_id(model_id, True)
transforms = ["middle-out"] if self.middle_out_enabled else None
max_tokens = None
if self.session_max_token > 0:
max_tokens = self.session_max_token
if tools:
# Use tool handling flow
response = self._send_with_tools(
messages=messages,
model_id=model_id,
tools=tools,
max_tokens=max_tokens,
transforms=transforms,
)
response_time = time.time() - start_time
return response.content or "", response.usage, response_time
elif stream:
# Use streaming flow
full_text, usage = self._stream_response(
messages=messages,
model_id=model_id,
max_tokens=max_tokens,
transforms=transforms,
on_chunk=on_stream_chunk,
)
response_time = time.time() - start_time
return full_text, usage, response_time
else:
# Non-streaming request
response = self.client.chat(
messages=messages,
model=model_id,
stream=False,
max_tokens=max_tokens,
transforms=transforms,
)
response_time = time.time() - start_time
if isinstance(response, ChatResponse):
return response.content or "", response.usage, response_time
else:
return "", None, response_time
def _send_with_tools(
self,
messages: List[Dict[str, Any]],
model_id: str,
tools: List[Dict[str, Any]],
max_tokens: Optional[int] = None,
transforms: Optional[List[str]] = None,
) -> ChatResponse:
"""
Send a request with tool call handling.
Args:
messages: API messages
model_id: Model ID
tools: Tool definitions
max_tokens: Max tokens
transforms: Transforms list
Returns:
Final ChatResponse
"""
max_loops = 5
loop_count = 0
api_messages = list(messages)
while loop_count < max_loops:
response = self.client.chat(
messages=api_messages,
model=model_id,
stream=False,
max_tokens=max_tokens,
tools=tools,
tool_choice="auto",
transforms=transforms,
)
if not isinstance(response, ChatResponse):
raise ValueError("Expected ChatResponse")
tool_calls = response.tool_calls
if not tool_calls:
return response
console.print(f"\n[dim yellow]🔧 AI requesting {len(tool_calls)} tool call(s)...[/]")
tool_results = []
for tc in tool_calls:
try:
args = json.loads(tc.function.arguments)
except json.JSONDecodeError as e:
self.logger.error(f"Failed to parse tool arguments: {e}")
tool_results.append({
"tool_call_id": tc.id,
"role": "tool",
"name": tc.function.name,
"content": json.dumps({"error": f"Invalid arguments: {e}"}),
})
continue
# Display tool call
args_display = ", ".join(
f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}"
for k, v in args.items()
)
console.print(f"[dim cyan] → {tc.function.name}({args_display})[/]")
# Execute tool
result = asyncio.run(self.execute_tool(tc.function.name, args))
if "error" in result:
console.print(f"[dim red] ✗ Error: {result['error']}[/]")
else:
self._display_tool_success(tc.function.name, result)
tool_results.append({
"tool_call_id": tc.id,
"role": "tool",
"name": tc.function.name,
"content": json.dumps(result),
})
# Add assistant message with tool calls
api_messages.append({
"role": "assistant",
"content": response.content,
"tool_calls": [
{
"id": tc.id,
"type": tc.type,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in tool_calls
],
})
api_messages.extend(tool_results)
console.print("\n[dim cyan]💭 Processing tool results...[/]")
loop_count += 1
self.logger.warning(f"Reached max tool loops ({max_loops})")
console.print(f"[bold yellow]⚠️ Reached maximum tool calls ({max_loops})[/]")
return response
def _display_tool_success(self, tool_name: str, result: Dict[str, Any]) -> None:
"""Display a success message for a tool call."""
if tool_name == "search_files":
count = result.get("count", 0)
console.print(f"[dim green] ✓ Found {count} file(s)[/]")
elif tool_name == "read_file":
size = result.get("size", 0)
truncated = " (truncated)" if result.get("truncated") else ""
console.print(f"[dim green] ✓ Read {size} bytes{truncated}[/]")
elif tool_name == "list_directory":
count = result.get("count", 0)
console.print(f"[dim green] ✓ Listed {count} item(s)[/]")
elif tool_name == "inspect_database":
if "table" in result:
console.print(f"[dim green] ✓ Inspected table: {result['table']}[/]")
else:
console.print(f"[dim green] ✓ Inspected database ({result.get('table_count', 0)} tables)[/]")
elif tool_name == "search_database":
count = result.get("count", 0)
console.print(f"[dim green] ✓ Found {count} match(es)[/]")
elif tool_name == "query_database":
count = result.get("count", 0)
console.print(f"[dim green] ✓ Query returned {count} row(s)[/]")
else:
console.print("[dim green] ✓ Success[/]")
def _stream_response(
self,
messages: List[Dict[str, Any]],
model_id: str,
max_tokens: Optional[int] = None,
transforms: Optional[List[str]] = None,
on_chunk: Optional[Callable[[str], None]] = None,
) -> Tuple[str, Optional[UsageStats]]:
"""
Stream a response with live display.
Args:
messages: API messages
model_id: Model ID
max_tokens: Max tokens
transforms: Transforms
on_chunk: Callback for chunks
Returns:
Tuple of (full_text, usage)
"""
response = self.client.chat(
messages=messages,
model=model_id,
stream=True,
max_tokens=max_tokens,
transforms=transforms,
)
if isinstance(response, ChatResponse):
return response.content or "", response.usage
full_text = ""
usage: Optional[UsageStats] = None
try:
with Live("", console=console, refresh_per_second=10) as live:
for chunk in response:
if chunk.error:
console.print(f"\n[bold red]Stream error: {chunk.error}[/]")
break
if chunk.delta_content:
full_text += chunk.delta_content
live.update(Markdown(full_text))
if on_chunk:
on_chunk(chunk.delta_content)
if chunk.usage:
usage = chunk.usage
except KeyboardInterrupt:
console.print("\n[bold yellow]⚠️ Streaming interrupted[/]")
return "", None
return full_text, usage
def add_to_history(
self,
prompt: str,
response: str,
usage: Optional[UsageStats] = None,
cost: float = 0.0,
) -> None:
"""
Add an exchange to the history.
Args:
prompt: User prompt
response: Assistant response
usage: Usage statistics
cost: Cost if not in usage
"""
entry = HistoryEntry(
prompt=prompt,
response=response,
prompt_tokens=usage.prompt_tokens if usage else 0,
completion_tokens=usage.completion_tokens if usage else 0,
msg_cost=usage.total_cost_usd if usage and usage.total_cost_usd else cost,
timestamp=time.time(),
)
self.history.append(entry)
self.current_index = len(self.history) - 1
self.stats.add_usage(usage, cost)
def save_conversation(self, name: str) -> bool:
"""
Save the current conversation.
Args:
name: Name for the saved conversation
Returns:
True if saved successfully
"""
if not self.history:
return False
data = [e.to_dict() for e in self.history]
self.db.save_conversation(name, data)
self.logger.info(f"Saved conversation: {name}")
return True
def load_conversation(self, name: str) -> bool:
"""
Load a saved conversation.
Args:
name: Name of the conversation to load
Returns:
True if loaded successfully
"""
data = self.db.load_conversation(name)
if not data:
return False
self.history.clear()
for entry_dict in data:
self.history.append(HistoryEntry(
prompt=entry_dict.get("prompt", ""),
response=entry_dict.get("response", ""),
prompt_tokens=entry_dict.get("prompt_tokens", 0),
completion_tokens=entry_dict.get("completion_tokens", 0),
msg_cost=entry_dict.get("msg_cost", 0.0),
))
self.current_index = len(self.history) - 1
self.memory_start_index = 0
self.stats = SessionStats() # Reset stats for loaded conversation
self.logger.info(f"Loaded conversation: {name}")
return True
def reset(self) -> None:
"""Reset the session state."""
self.history.clear()
self.stats = SessionStats()
self.system_prompt = ""
self.memory_start_index = 0
self.current_index = 0
self.logger.info("Session reset")
def check_warnings(self) -> List[str]:
"""
Check for cost and credit warnings.
Returns:
List of warning messages
"""
warnings = []
# Check last message cost
if self.history:
last_cost = self.history[-1].msg_cost
threshold = self.settings.cost_warning_threshold
if last_cost > threshold:
warnings.append(
f"High cost: ${last_cost:.4f} exceeds threshold ${threshold:.4f}"
)
# Check credits
credits = self.client.get_credits()
if credits:
left = credits.get("credits_left", 0)
total = credits.get("total_credits", 0)
if left < LOW_CREDIT_AMOUNT:
warnings.append(f"Low credits: ${left:.2f} remaining!")
elif total > 0 and left < total * LOW_CREDIT_RATIO:
warnings.append(f"Credits low: less than 10% remaining (${left:.2f})")
return warnings

28
oai/mcp/__init__.py Normal file
View File

@@ -0,0 +1,28 @@
"""
Model Context Protocol (MCP) integration for oAI.
This package provides filesystem and database access capabilities
through the MCP standard, allowing AI models to interact with
local files and SQLite databases safely.
Key components:
- MCPManager: High-level manager for MCP operations
- MCPFilesystemServer: Filesystem and database access implementation
- GitignoreParser: Pattern matching for .gitignore support
- SQLiteQueryValidator: Query safety validation
- CrossPlatformMCPConfig: OS-specific configuration
"""
from oai.mcp.manager import MCPManager
from oai.mcp.server import MCPFilesystemServer
from oai.mcp.gitignore import GitignoreParser
from oai.mcp.validators import SQLiteQueryValidator
from oai.mcp.platform import CrossPlatformMCPConfig
__all__ = [
"MCPManager",
"MCPFilesystemServer",
"GitignoreParser",
"SQLiteQueryValidator",
"CrossPlatformMCPConfig",
]

166
oai/mcp/gitignore.py Normal file
View File

@@ -0,0 +1,166 @@
"""
Gitignore pattern parsing for oAI MCP.
This module implements .gitignore pattern matching to filter files
during MCP filesystem operations.
"""
import fnmatch
from pathlib import Path
from typing import List, Tuple
from oai.utils.logging import get_logger
class GitignoreParser:
"""
Parse and apply .gitignore patterns.
Supports standard gitignore syntax including:
- Wildcards (*) and double wildcards (**)
- Directory-only patterns (ending with /)
- Negation patterns (starting with !)
- Comments (lines starting with #)
Patterns are applied relative to the directory containing
the .gitignore file.
"""
def __init__(self):
"""Initialize an empty pattern collection."""
# List of (pattern, is_negation, source_dir)
self.patterns: List[Tuple[str, bool, Path]] = []
def add_gitignore(self, gitignore_path: Path) -> None:
"""
Parse and add patterns from a .gitignore file.
Args:
gitignore_path: Path to the .gitignore file
"""
logger = get_logger()
if not gitignore_path.exists():
return
try:
source_dir = gitignore_path.parent
with open(gitignore_path, "r", encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
line = line.rstrip("\n\r")
# Skip empty lines and comments
if not line or line.startswith("#"):
continue
# Check for negation pattern
is_negation = line.startswith("!")
if is_negation:
line = line[1:]
# Remove leading slash (make relative to gitignore location)
if line.startswith("/"):
line = line[1:]
self.patterns.append((line, is_negation, source_dir))
logger.debug(
f"Loaded {len(self.patterns)} patterns from {gitignore_path}"
)
except Exception as e:
logger.warning(f"Error reading {gitignore_path}: {e}")
def should_ignore(self, path: Path) -> bool:
"""
Check if a path should be ignored based on gitignore patterns.
Patterns are evaluated in order, with later patterns overriding
earlier ones. Negation patterns (starting with !) un-ignore
previously matched paths.
Args:
path: Path to check
Returns:
True if the path should be ignored
"""
if not self.patterns:
return False
ignored = False
for pattern, is_negation, source_dir in self.patterns:
# Only apply pattern if path is under the source directory
try:
rel_path = path.relative_to(source_dir)
except ValueError:
# Path is not relative to this gitignore's directory
continue
rel_path_str = str(rel_path)
# Check if pattern matches
if self._match_pattern(pattern, rel_path_str, path.is_dir()):
if is_negation:
ignored = False # Negation patterns un-ignore
else:
ignored = True
return ignored
def _match_pattern(self, pattern: str, path: str, is_dir: bool) -> bool:
"""
Match a gitignore pattern against a path.
Args:
pattern: The gitignore pattern
path: The relative path string to match
is_dir: Whether the path is a directory
Returns:
True if the pattern matches
"""
# Directory-only pattern (ends with /)
if pattern.endswith("/"):
if not is_dir:
return False
pattern = pattern[:-1]
# Handle ** patterns (matches any number of directories)
if "**" in pattern:
pattern_parts = pattern.split("**")
if len(pattern_parts) == 2:
prefix, suffix = pattern_parts
# Match if path starts with prefix and ends with suffix
if prefix:
if not path.startswith(prefix.rstrip("/")):
return False
if suffix:
suffix = suffix.lstrip("/")
if not (path.endswith(suffix) or f"/{suffix}" in path):
return False
return True
# Direct match using fnmatch
if fnmatch.fnmatch(path, pattern):
return True
# Match as subdirectory pattern (pattern without / matches in any directory)
if "/" not in pattern:
parts = path.split("/")
if any(fnmatch.fnmatch(part, pattern) for part in parts):
return True
return False
def clear(self) -> None:
"""Clear all loaded patterns."""
self.patterns = []
@property
def pattern_count(self) -> int:
"""Get the number of loaded patterns."""
return len(self.patterns)

1365
oai/mcp/manager.py Normal file

File diff suppressed because it is too large Load Diff

228
oai/mcp/platform.py Normal file
View File

@@ -0,0 +1,228 @@
"""
Cross-platform MCP configuration for oAI.
This module handles OS-specific configuration, path handling,
and security checks for the MCP filesystem server.
"""
import os
import platform
import subprocess
from pathlib import Path
from typing import List, Dict, Any, Optional
from oai.constants import SYSTEM_DIRS_BLACKLIST
from oai.utils.logging import get_logger
class CrossPlatformMCPConfig:
"""
Handle OS-specific MCP configuration.
Provides methods for path normalization, security validation,
and OS-specific default directories.
Attributes:
system: Operating system name
is_macos: Whether running on macOS
is_linux: Whether running on Linux
is_windows: Whether running on Windows
"""
def __init__(self):
"""Initialize platform detection."""
self.system = platform.system()
self.is_macos = self.system == "Darwin"
self.is_linux = self.system == "Linux"
self.is_windows = self.system == "Windows"
logger = get_logger()
logger.info(f"Detected OS: {self.system}")
def get_default_allowed_dirs(self) -> List[Path]:
"""
Get safe default directories for the current OS.
Returns:
List of default directories that are safe to access
"""
home = Path.home()
if self.is_macos:
return [
home / "Documents",
home / "Desktop",
home / "Downloads",
]
elif self.is_linux:
dirs = [home / "Documents"]
# Try to get XDG directories
try:
for xdg_dir in ["DOCUMENTS", "DESKTOP", "DOWNLOAD"]:
result = subprocess.run(
["xdg-user-dir", xdg_dir],
capture_output=True,
text=True,
timeout=1
)
if result.returncode == 0:
dir_path = Path(result.stdout.strip())
if dir_path.exists():
dirs.append(dir_path)
except (subprocess.TimeoutExpired, FileNotFoundError):
# Fallback to standard locations
dirs.extend([
home / "Desktop",
home / "Downloads",
])
return list(set(dirs))
elif self.is_windows:
return [
home / "Documents",
home / "Desktop",
home / "Downloads",
]
# Fallback for unknown OS
return [home]
def get_python_command(self) -> str:
"""
Get the Python executable path.
Returns:
Path to the Python executable
"""
import sys
return sys.executable
def get_filesystem_warning(self) -> str:
"""
Get OS-specific security warning message.
Returns:
Warning message for the current OS
"""
if self.is_macos:
return """
Note: macOS Security
The Filesystem MCP server needs access to your selected folder.
You may see a security prompt - click 'Allow' to proceed.
(System Settings > Privacy & Security > Files and Folders)
"""
elif self.is_linux:
return """
Note: Linux Security
The Filesystem MCP server will access your selected folder.
Ensure oAI has appropriate file permissions.
"""
elif self.is_windows:
return """
Note: Windows Security
The Filesystem MCP server will access your selected folder.
You may need to grant file access permissions.
"""
return ""
def normalize_path(self, path: str) -> Path:
"""
Normalize a path for the current OS.
Expands user directory (~) and resolves to absolute path.
Args:
path: Path string to normalize
Returns:
Normalized absolute Path
"""
return Path(os.path.expanduser(path)).resolve()
def is_system_directory(self, path: Path) -> bool:
"""
Check if a path is a protected system directory.
Args:
path: Path to check
Returns:
True if the path is a system directory
"""
path_str = str(path)
for blocked in SYSTEM_DIRS_BLACKLIST:
if path_str.startswith(blocked):
return True
return False
def is_safe_path(self, requested_path: Path, allowed_dirs: List[Path]) -> bool:
"""
Check if a path is within allowed directories.
Args:
requested_path: Path being requested
allowed_dirs: List of allowed parent directories
Returns:
True if the path is within an allowed directory
"""
try:
requested = requested_path.resolve()
for allowed in allowed_dirs:
try:
allowed_resolved = allowed.resolve()
requested.relative_to(allowed_resolved)
return True
except ValueError:
continue
return False
except Exception:
return False
def get_folder_stats(self, folder: Path) -> Dict[str, Any]:
"""
Get statistics for a folder.
Args:
folder: Path to the folder
Returns:
Dictionary with folder statistics:
- exists: Whether the folder exists
- file_count: Number of files (if exists)
- total_size: Total size in bytes (if exists)
- size_mb: Size in megabytes (if exists)
- error: Error message (if any)
"""
logger = get_logger()
try:
if not folder.exists() or not folder.is_dir():
return {"exists": False}
file_count = 0
total_size = 0
for item in folder.rglob("*"):
if item.is_file():
file_count += 1
try:
total_size += item.stat().st_size
except (OSError, PermissionError):
pass
return {
"exists": True,
"file_count": file_count,
"total_size": total_size,
"size_mb": total_size / (1024 * 1024),
}
except Exception as e:
logger.error(f"Error getting folder stats for {folder}: {e}")
return {"exists": False, "error": str(e)}

1368
oai/mcp/server.py Normal file

File diff suppressed because it is too large Load Diff

123
oai/mcp/validators.py Normal file
View File

@@ -0,0 +1,123 @@
"""
Query validation for oAI MCP database operations.
This module provides safety validation for SQL queries to ensure
only read-only operations are executed.
"""
import re
from typing import Tuple
from oai.constants import DANGEROUS_SQL_KEYWORDS
class SQLiteQueryValidator:
"""
Validate SQLite queries for read-only safety.
Ensures that only SELECT queries (including CTEs) are allowed
and blocks potentially dangerous operations like INSERT, UPDATE,
DELETE, DROP, etc.
"""
@staticmethod
def is_safe_query(query: str) -> Tuple[bool, str]:
"""
Validate that a query is a safe read-only SELECT.
The validation:
1. Checks that query starts with SELECT or WITH
2. Strips string literals before checking for dangerous keywords
3. Blocks any dangerous keywords outside of string literals
Args:
query: SQL query string to validate
Returns:
Tuple of (is_safe, error_message)
- is_safe: True if the query is safe to execute
- error_message: Description of why query is unsafe (empty if safe)
Examples:
>>> SQLiteQueryValidator.is_safe_query("SELECT * FROM users")
(True, "")
>>> SQLiteQueryValidator.is_safe_query("DELETE FROM users")
(False, "Only SELECT queries are allowed...")
>>> SQLiteQueryValidator.is_safe_query("SELECT 'DELETE' FROM users")
(True, "") # 'DELETE' is inside a string literal
"""
query_upper = query.strip().upper()
# Must start with SELECT or WITH (for CTEs)
if not (query_upper.startswith("SELECT") or query_upper.startswith("WITH")):
return False, "Only SELECT queries are allowed (including WITH/CTE)"
# Remove string literals before checking for dangerous keywords
# This prevents false positives when keywords appear in data
query_no_strings = re.sub(r"'[^']*'", "", query_upper)
query_no_strings = re.sub(r'"[^"]*"', "", query_no_strings)
# Check for dangerous keywords outside of quotes
for keyword in DANGEROUS_SQL_KEYWORDS:
if re.search(r"\b" + keyword + r"\b", query_no_strings):
return False, f"Keyword '{keyword}' not allowed in read-only mode"
return True, ""
@staticmethod
def sanitize_table_name(table_name: str) -> str:
"""
Sanitize a table name to prevent SQL injection.
Only allows alphanumeric characters and underscores.
Args:
table_name: Table name to sanitize
Returns:
Sanitized table name
Raises:
ValueError: If table name contains invalid characters
"""
# Remove any characters that aren't alphanumeric or underscore
sanitized = re.sub(r"[^\w]", "", table_name)
if not sanitized:
raise ValueError("Table name cannot be empty after sanitization")
if sanitized != table_name:
raise ValueError(
f"Table name contains invalid characters: {table_name}"
)
return sanitized
@staticmethod
def sanitize_column_name(column_name: str) -> str:
"""
Sanitize a column name to prevent SQL injection.
Only allows alphanumeric characters and underscores.
Args:
column_name: Column name to sanitize
Returns:
Sanitized column name
Raises:
ValueError: If column name contains invalid characters
"""
# Remove any characters that aren't alphanumeric or underscore
sanitized = re.sub(r"[^\w]", "", column_name)
if not sanitized:
raise ValueError("Column name cannot be empty after sanitization")
if sanitized != column_name:
raise ValueError(
f"Column name contains invalid characters: {column_name}"
)
return sanitized

32
oai/providers/__init__.py Normal file
View File

@@ -0,0 +1,32 @@
"""
Provider abstraction for oAI.
This module provides a unified interface for AI model providers,
enabling easy extension to support additional providers beyond OpenRouter.
"""
from oai.providers.base import (
AIProvider,
ChatMessage,
ChatResponse,
ToolCall,
ToolFunction,
UsageStats,
ModelInfo,
ProviderCapabilities,
)
from oai.providers.openrouter import OpenRouterProvider
__all__ = [
# Base classes and types
"AIProvider",
"ChatMessage",
"ChatResponse",
"ToolCall",
"ToolFunction",
"UsageStats",
"ModelInfo",
"ProviderCapabilities",
# Provider implementations
"OpenRouterProvider",
]

413
oai/providers/base.py Normal file
View File

@@ -0,0 +1,413 @@
"""
Abstract base provider for AI model integration.
This module defines the interface that all AI providers must implement,
along with common data structures for requests and responses.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Union,
)
class MessageRole(str, Enum):
"""Message roles in a conversation."""
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
@dataclass
class ToolFunction:
"""
Represents a function within a tool call.
Attributes:
name: The function name
arguments: JSON string of function arguments
"""
name: str
arguments: str
@dataclass
class ToolCall:
"""
Represents a tool/function call requested by the model.
Attributes:
id: Unique identifier for this tool call
type: Type of tool call (usually "function")
function: The function being called
"""
id: str
type: str
function: ToolFunction
@dataclass
class UsageStats:
"""
Token usage statistics from an API response.
Attributes:
prompt_tokens: Number of tokens in the prompt
completion_tokens: Number of tokens in the completion
total_tokens: Total tokens used
total_cost_usd: Cost in USD (if available from API)
"""
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
total_cost_usd: Optional[float] = None
@property
def input_tokens(self) -> int:
"""Alias for prompt_tokens."""
return self.prompt_tokens
@property
def output_tokens(self) -> int:
"""Alias for completion_tokens."""
return self.completion_tokens
@dataclass
class ChatMessage:
"""
A single message in a chat conversation.
Attributes:
role: The role of the message sender
content: Message content (text or structured content blocks)
name: Optional name for the sender
tool_calls: List of tool calls (for assistant messages)
tool_call_id: Tool call ID this message responds to (for tool messages)
"""
role: str
content: Union[str, List[Dict[str, Any]], None] = None
name: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = None
tool_call_id: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format for API requests."""
result: Dict[str, Any] = {"role": self.role}
if self.content is not None:
result["content"] = self.content
if self.name:
result["name"] = self.name
if self.tool_calls:
result["tool_calls"] = [
{
"id": tc.id,
"type": tc.type,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in self.tool_calls
]
if self.tool_call_id:
result["tool_call_id"] = self.tool_call_id
return result
@dataclass
class ChatResponseChoice:
"""
A single choice in a chat response.
Attributes:
index: Index of this choice
message: The response message
finish_reason: Why the response ended
"""
index: int
message: ChatMessage
finish_reason: Optional[str] = None
@dataclass
class ChatResponse:
"""
Response from a chat completion request.
Attributes:
id: Unique identifier for this response
choices: List of response choices
usage: Token usage statistics
model: Model that generated this response
created: Unix timestamp of creation
"""
id: str
choices: List[ChatResponseChoice]
usage: Optional[UsageStats] = None
model: Optional[str] = None
created: Optional[int] = None
@property
def message(self) -> Optional[ChatMessage]:
"""Get the first choice's message."""
if self.choices:
return self.choices[0].message
return None
@property
def content(self) -> Optional[str]:
"""Get the text content of the first choice."""
msg = self.message
if msg and isinstance(msg.content, str):
return msg.content
return None
@property
def tool_calls(self) -> Optional[List[ToolCall]]:
"""Get tool calls from the first choice."""
msg = self.message
if msg:
return msg.tool_calls
return None
@dataclass
class StreamChunk:
"""
A single chunk from a streaming response.
Attributes:
id: Response ID
delta_content: New content in this chunk
finish_reason: Finish reason (if this is the last chunk)
usage: Usage stats (usually in the last chunk)
error: Error message if something went wrong
"""
id: str
delta_content: Optional[str] = None
finish_reason: Optional[str] = None
usage: Optional[UsageStats] = None
error: Optional[str] = None
@dataclass
class ModelInfo:
"""
Information about an AI model.
Attributes:
id: Unique model identifier
name: Display name
description: Model description
context_length: Maximum context window size
pricing: Pricing info (input/output per million tokens)
supported_parameters: List of supported API parameters
input_modalities: Supported input types (text, image, etc.)
output_modalities: Supported output types
"""
id: str
name: str
description: str = ""
context_length: int = 0
pricing: Dict[str, float] = field(default_factory=dict)
supported_parameters: List[str] = field(default_factory=list)
input_modalities: List[str] = field(default_factory=lambda: ["text"])
output_modalities: List[str] = field(default_factory=lambda: ["text"])
def supports_images(self) -> bool:
"""Check if model supports image input."""
return "image" in self.input_modalities
def supports_tools(self) -> bool:
"""Check if model supports function calling/tools."""
return "tools" in self.supported_parameters or "functions" in self.supported_parameters
def supports_streaming(self) -> bool:
"""Check if model supports streaming responses."""
return "stream" in self.supported_parameters
def supports_online(self) -> bool:
"""Check if model supports web search (online mode)."""
return self.supports_tools()
@dataclass
class ProviderCapabilities:
"""
Capabilities supported by a provider.
Attributes:
streaming: Provider supports streaming responses
tools: Provider supports function calling
images: Provider supports image inputs
online: Provider supports web search
max_context: Maximum context length across all models
"""
streaming: bool = True
tools: bool = True
images: bool = True
online: bool = False
max_context: int = 128000
class AIProvider(ABC):
"""
Abstract base class for AI model providers.
All provider implementations must inherit from this class
and implement the required abstract methods.
"""
def __init__(self, api_key: str, base_url: Optional[str] = None):
"""
Initialize the provider.
Args:
api_key: API key for authentication
base_url: Optional custom base URL for the API
"""
self.api_key = api_key
self.base_url = base_url
@property
@abstractmethod
def name(self) -> str:
"""Get the provider name."""
pass
@property
@abstractmethod
def capabilities(self) -> ProviderCapabilities:
"""Get provider capabilities."""
pass
@abstractmethod
def list_models(self) -> List[ModelInfo]:
"""
Fetch available models from the provider.
Returns:
List of available models with their info
"""
pass
@abstractmethod
def get_model(self, model_id: str) -> Optional[ModelInfo]:
"""
Get information about a specific model.
Args:
model_id: The model identifier
Returns:
Model information or None if not found
"""
pass
@abstractmethod
def chat(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
**kwargs: Any,
) -> Union[ChatResponse, Iterator[StreamChunk]]:
"""
Send a chat completion request.
Args:
model: Model ID to use
messages: List of chat messages
stream: Whether to stream the response
max_tokens: Maximum tokens in response
temperature: Sampling temperature
tools: List of tool definitions for function calling
tool_choice: How to handle tool selection ("auto", "none", etc.)
**kwargs: Additional provider-specific parameters
Returns:
ChatResponse for non-streaming, Iterator[StreamChunk] for streaming
"""
pass
@abstractmethod
async def chat_async(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
**kwargs: Any,
) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
"""
Send an async chat completion request.
Args:
model: Model ID to use
messages: List of chat messages
stream: Whether to stream the response
max_tokens: Maximum tokens in response
temperature: Sampling temperature
tools: List of tool definitions for function calling
tool_choice: How to handle tool selection
**kwargs: Additional provider-specific parameters
Returns:
ChatResponse for non-streaming, AsyncIterator[StreamChunk] for streaming
"""
pass
@abstractmethod
def get_credits(self) -> Optional[Dict[str, Any]]:
"""
Get account credit/balance information.
Returns:
Dict with credit info or None if not supported
"""
pass
def validate_api_key(self) -> bool:
"""
Validate that the API key is valid.
Returns:
True if API key is valid
"""
try:
self.list_models()
return True
except Exception:
return False

623
oai/providers/openrouter.py Normal file
View File

@@ -0,0 +1,623 @@
"""
OpenRouter provider implementation.
This module implements the AIProvider interface for OpenRouter,
supporting chat completions, streaming, and function calling.
"""
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
import requests
from openrouter import OpenRouter
from oai.constants import APP_NAME, APP_URL, DEFAULT_BASE_URL
from oai.providers.base import (
AIProvider,
ChatMessage,
ChatResponse,
ChatResponseChoice,
ModelInfo,
ProviderCapabilities,
StreamChunk,
ToolCall,
ToolFunction,
UsageStats,
)
from oai.utils.logging import get_logger
class OpenRouterProvider(AIProvider):
"""
OpenRouter API provider implementation.
Provides access to multiple AI models through OpenRouter's unified API,
supporting chat completions, streaming responses, and function calling.
Attributes:
client: The underlying OpenRouter client
_models_cache: Cached list of available models
"""
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
app_name: str = APP_NAME,
app_url: str = APP_URL,
):
"""
Initialize the OpenRouter provider.
Args:
api_key: OpenRouter API key
base_url: Optional custom base URL
app_name: Application name for API headers
app_url: Application URL for API headers
"""
super().__init__(api_key, base_url or DEFAULT_BASE_URL)
self.app_name = app_name
self.app_url = app_url
self.client = OpenRouter(api_key=api_key)
self._models_cache: Optional[List[ModelInfo]] = None
self._raw_models_cache: Optional[List[Dict[str, Any]]] = None
self.logger = get_logger()
self.logger.info(f"OpenRouter provider initialized with base URL: {self.base_url}")
@property
def name(self) -> str:
"""Get the provider name."""
return "OpenRouter"
@property
def capabilities(self) -> ProviderCapabilities:
"""Get provider capabilities."""
return ProviderCapabilities(
streaming=True,
tools=True,
images=True,
online=True,
max_context=2000000, # Claude models support up to 200k
)
def _get_headers(self) -> Dict[str, str]:
"""Get standard HTTP headers for API requests."""
headers = {
"HTTP-Referer": self.app_url,
"X-Title": self.app_name,
}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
return headers
def _parse_model(self, model_data: Dict[str, Any]) -> ModelInfo:
"""
Parse raw model data into ModelInfo.
Args:
model_data: Raw model data from API
Returns:
Parsed ModelInfo object
"""
architecture = model_data.get("architecture", {})
pricing_data = model_data.get("pricing", {})
# Parse pricing (convert from string to float if needed)
pricing = {}
for key in ["prompt", "completion"]:
value = pricing_data.get(key)
if value is not None:
try:
# Convert from per-token to per-million-tokens
pricing[key] = float(value) * 1_000_000
except (ValueError, TypeError):
pricing[key] = 0.0
return ModelInfo(
id=model_data.get("id", ""),
name=model_data.get("name", model_data.get("id", "")),
description=model_data.get("description", ""),
context_length=model_data.get("context_length", 0),
pricing=pricing,
supported_parameters=model_data.get("supported_parameters", []),
input_modalities=architecture.get("input_modalities", ["text"]),
output_modalities=architecture.get("output_modalities", ["text"]),
)
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
"""
Fetch available models from OpenRouter.
Args:
filter_text_only: If True, exclude video-only models
Returns:
List of available models
Raises:
Exception: If API request fails
"""
if self._models_cache is not None:
return self._models_cache
try:
response = requests.get(
f"{self.base_url}/models",
headers=self._get_headers(),
timeout=10,
)
response.raise_for_status()
raw_models = response.json().get("data", [])
self._raw_models_cache = raw_models
models = []
for model_data in raw_models:
# Optionally filter out video-only models
if filter_text_only:
modalities = model_data.get("modalities", [])
if modalities and "video" in modalities and "text" not in modalities:
continue
models.append(self._parse_model(model_data))
self._models_cache = models
self.logger.info(f"Fetched {len(models)} models from OpenRouter")
return models
except requests.RequestException as e:
self.logger.error(f"Failed to fetch models: {e}")
raise
def get_raw_models(self) -> List[Dict[str, Any]]:
"""
Get raw model data as returned by the API.
Useful for accessing provider-specific fields not in ModelInfo.
Returns:
List of raw model dictionaries
"""
if self._raw_models_cache is None:
self.list_models()
return self._raw_models_cache or []
def get_model(self, model_id: str) -> Optional[ModelInfo]:
"""
Get information about a specific model.
Args:
model_id: The model identifier
Returns:
Model information or None if not found
"""
models = self.list_models()
for model in models:
if model.id == model_id:
return model
return None
def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
"""
Get raw model data for a specific model.
Args:
model_id: The model identifier
Returns:
Raw model dictionary or None if not found
"""
raw_models = self.get_raw_models()
for model in raw_models:
if model.get("id") == model_id:
return model
return None
def _convert_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
"""
Convert ChatMessage objects to API format.
Args:
messages: List of ChatMessage objects
Returns:
List of message dictionaries for the API
"""
return [msg.to_dict() for msg in messages]
def _parse_usage(self, usage_data: Any) -> Optional[UsageStats]:
"""
Parse usage data from API response.
Args:
usage_data: Raw usage data from API
Returns:
Parsed UsageStats or None
"""
if not usage_data:
return None
# Handle both attribute and dict access
prompt_tokens = 0
completion_tokens = 0
total_cost = None
if hasattr(usage_data, "prompt_tokens"):
prompt_tokens = getattr(usage_data, "prompt_tokens", 0) or 0
elif isinstance(usage_data, dict):
prompt_tokens = usage_data.get("prompt_tokens", 0) or 0
if hasattr(usage_data, "completion_tokens"):
completion_tokens = getattr(usage_data, "completion_tokens", 0) or 0
elif isinstance(usage_data, dict):
completion_tokens = usage_data.get("completion_tokens", 0) or 0
# Try alternative naming (input_tokens/output_tokens)
if prompt_tokens == 0:
if hasattr(usage_data, "input_tokens"):
prompt_tokens = getattr(usage_data, "input_tokens", 0) or 0
elif isinstance(usage_data, dict):
prompt_tokens = usage_data.get("input_tokens", 0) or 0
if completion_tokens == 0:
if hasattr(usage_data, "output_tokens"):
completion_tokens = getattr(usage_data, "output_tokens", 0) or 0
elif isinstance(usage_data, dict):
completion_tokens = usage_data.get("output_tokens", 0) or 0
# Get cost if available
if hasattr(usage_data, "total_cost_usd"):
total_cost = getattr(usage_data, "total_cost_usd", None)
elif isinstance(usage_data, dict):
total_cost = usage_data.get("total_cost_usd")
return UsageStats(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
total_cost_usd=float(total_cost) if total_cost else None,
)
def _parse_tool_calls(self, tool_calls_data: Any) -> Optional[List[ToolCall]]:
"""
Parse tool calls from API response.
Args:
tool_calls_data: Raw tool calls data
Returns:
List of ToolCall objects or None
"""
if not tool_calls_data:
return None
tool_calls = []
for tc in tool_calls_data:
# Handle both attribute and dict access
if hasattr(tc, "id"):
tc_id = tc.id
tc_type = getattr(tc, "type", "function")
func = tc.function
func_name = func.name
func_args = func.arguments
else:
tc_id = tc.get("id", "")
tc_type = tc.get("type", "function")
func = tc.get("function", {})
func_name = func.get("name", "")
func_args = func.get("arguments", "{}")
tool_calls.append(
ToolCall(
id=tc_id,
type=tc_type,
function=ToolFunction(name=func_name, arguments=func_args),
)
)
return tool_calls if tool_calls else None
def _parse_response(self, response: Any) -> ChatResponse:
"""
Parse API response into ChatResponse.
Args:
response: Raw API response
Returns:
Parsed ChatResponse
"""
choices = []
for choice in response.choices:
msg = choice.message
message = ChatMessage(
role=msg.role if hasattr(msg, "role") else "assistant",
content=msg.content if hasattr(msg, "content") else None,
tool_calls=self._parse_tool_calls(
getattr(msg, "tool_calls", None)
),
)
choices.append(
ChatResponseChoice(
index=choice.index if hasattr(choice, "index") else 0,
message=message,
finish_reason=getattr(choice, "finish_reason", None),
)
)
return ChatResponse(
id=response.id if hasattr(response, "id") else "",
choices=choices,
usage=self._parse_usage(getattr(response, "usage", None)),
model=getattr(response, "model", None),
created=getattr(response, "created", None),
)
def chat(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
transforms: Optional[List[str]] = None,
**kwargs: Any,
) -> Union[ChatResponse, Iterator[StreamChunk]]:
"""
Send a chat completion request to OpenRouter.
Args:
model: Model ID to use
messages: List of chat messages
stream: Whether to stream the response
max_tokens: Maximum tokens in response
temperature: Sampling temperature (0-2)
tools: List of tool definitions for function calling
tool_choice: How to handle tool selection ("auto", "none", etc.)
transforms: List of transforms (e.g., ["middle-out"])
**kwargs: Additional parameters
Returns:
ChatResponse for non-streaming, Iterator[StreamChunk] for streaming
"""
# Build request parameters
params: Dict[str, Any] = {
"model": model,
"messages": self._convert_messages(messages),
"stream": stream,
"http_headers": self._get_headers(),
}
# Request usage stats in streaming responses
if stream:
params["stream_options"] = {"include_usage": True}
if max_tokens is not None:
params["max_tokens"] = max_tokens
if temperature is not None:
params["temperature"] = temperature
if tools:
params["tools"] = tools
params["tool_choice"] = tool_choice or "auto"
if transforms:
params["transforms"] = transforms
# Add any additional parameters
params.update(kwargs)
self.logger.debug(f"Sending chat request to model {model}")
try:
response = self.client.chat.send(**params)
if stream:
return self._stream_response(response)
else:
return self._parse_response(response)
except Exception as e:
self.logger.error(f"Chat request failed: {e}")
raise
def _stream_response(self, response: Any) -> Iterator[StreamChunk]:
"""
Process a streaming response.
Args:
response: Streaming response from API
Yields:
StreamChunk objects
"""
last_usage = None
try:
for chunk in response:
# Check for errors
if hasattr(chunk, "error") and chunk.error:
yield StreamChunk(
id=getattr(chunk, "id", ""),
error=chunk.error.message if hasattr(chunk.error, "message") else str(chunk.error),
)
return
# Extract delta content
delta_content = None
finish_reason = None
if hasattr(chunk, "choices") and chunk.choices:
choice = chunk.choices[0]
if hasattr(choice, "delta"):
delta = choice.delta
if hasattr(delta, "content") and delta.content:
delta_content = delta.content
finish_reason = getattr(choice, "finish_reason", None)
# Track usage from last chunk
if hasattr(chunk, "usage") and chunk.usage:
last_usage = self._parse_usage(chunk.usage)
yield StreamChunk(
id=getattr(chunk, "id", ""),
delta_content=delta_content,
finish_reason=finish_reason,
usage=last_usage if finish_reason else None,
)
except Exception as e:
self.logger.error(f"Stream error: {e}")
yield StreamChunk(id="", error=str(e))
async def chat_async(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
**kwargs: Any,
) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
"""
Send an async chat completion request.
Note: Currently wraps the sync implementation.
TODO: Implement true async support when OpenRouter SDK supports it.
Args:
model: Model ID to use
messages: List of chat messages
stream: Whether to stream the response
max_tokens: Maximum tokens in response
temperature: Sampling temperature
tools: List of tool definitions
tool_choice: Tool selection mode
**kwargs: Additional parameters
Returns:
ChatResponse for non-streaming, AsyncIterator for streaming
"""
# For now, use sync implementation
# TODO: Add true async when SDK supports it
result = self.chat(
model=model,
messages=messages,
stream=stream,
max_tokens=max_tokens,
temperature=temperature,
tools=tools,
tool_choice=tool_choice,
**kwargs,
)
if stream and isinstance(result, Iterator):
# Convert sync iterator to async
async def async_iter() -> AsyncIterator[StreamChunk]:
for chunk in result:
yield chunk
return async_iter()
return result
def get_credits(self) -> Optional[Dict[str, Any]]:
"""
Get OpenRouter account credit information.
Returns:
Dict with credit info:
- total_credits: Total credits purchased
- used_credits: Credits used
- credits_left: Remaining credits
Raises:
Exception: If API request fails
"""
if not self.api_key:
return None
try:
response = requests.get(
f"{self.base_url}/credits",
headers=self._get_headers(),
timeout=10,
)
response.raise_for_status()
data = response.json().get("data", {})
total_credits = float(data.get("total_credits", 0))
total_usage = float(data.get("total_usage", 0))
credits_left = total_credits - total_usage
return {
"total_credits": total_credits,
"used_credits": total_usage,
"credits_left": credits_left,
"total_credits_formatted": f"${total_credits:.2f}",
"used_credits_formatted": f"${total_usage:.2f}",
"credits_left_formatted": f"${credits_left:.2f}",
}
except Exception as e:
self.logger.error(f"Failed to fetch credits: {e}")
return None
def clear_cache(self) -> None:
"""Clear the models cache to force a refresh."""
self._models_cache = None
self._raw_models_cache = None
self.logger.debug("Models cache cleared")
def get_effective_model_id(self, model_id: str, online_enabled: bool) -> str:
"""
Get the effective model ID with online suffix if needed.
Args:
model_id: Base model ID
online_enabled: Whether online mode is enabled
Returns:
Model ID with :online suffix if applicable
"""
if online_enabled and not model_id.endswith(":online"):
return f"{model_id}:online"
return model_id
def estimate_cost(
self,
model_id: str,
input_tokens: int,
output_tokens: int,
) -> float:
"""
Estimate the cost for a completion.
Args:
model_id: Model ID
input_tokens: Number of input tokens
output_tokens: Number of output tokens
Returns:
Estimated cost in USD
"""
model = self.get_model(model_id)
if model and model.pricing:
input_cost = model.pricing.get("prompt", 0) * input_tokens / 1_000_000
output_cost = model.pricing.get("completion", 0) * output_tokens / 1_000_000
return input_cost + output_cost
# Fallback to default pricing if model not found
from oai.constants import MODEL_PRICING
input_cost = MODEL_PRICING["input"] * input_tokens / 1_000_000
output_cost = MODEL_PRICING["output"] * output_tokens / 1_000_000
return input_cost + output_cost

2
oai/py.typed Normal file
View File

@@ -0,0 +1,2 @@
# Marker file for PEP 561
# This package supports type checking

51
oai/ui/__init__.py Normal file
View File

@@ -0,0 +1,51 @@
"""
UI utilities for oAI.
This module provides rich terminal UI components and display helpers
for the chat application.
"""
from oai.ui.console import (
console,
clear_screen,
display_panel,
display_table,
display_markdown,
print_error,
print_warning,
print_success,
print_info,
)
from oai.ui.tables import (
create_model_table,
create_stats_table,
create_help_table,
display_paginated_table,
)
from oai.ui.prompts import (
prompt_confirm,
prompt_choice,
prompt_input,
)
__all__ = [
# Console utilities
"console",
"clear_screen",
"display_panel",
"display_table",
"display_markdown",
"print_error",
"print_warning",
"print_success",
"print_info",
# Table utilities
"create_model_table",
"create_stats_table",
"create_help_table",
"display_paginated_table",
# Prompt utilities
"prompt_confirm",
"prompt_choice",
"prompt_input",
]

242
oai/ui/console.py Normal file
View File

@@ -0,0 +1,242 @@
"""
Console utilities for oAI.
This module provides the Rich console instance and common display functions
for formatted terminal output.
"""
from typing import Any, Optional
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
# Global console instance for the application
console = Console()
def clear_screen() -> None:
"""
Clear the terminal screen.
Uses ANSI escape codes for fast clearing, with a fallback
for terminals that don't support them.
"""
try:
print("\033[H\033[J", end="", flush=True)
except Exception:
# Fallback: print many newlines
print("\n" * 100)
def display_panel(
content: Any,
title: Optional[str] = None,
subtitle: Optional[str] = None,
border_style: str = "green",
title_align: str = "left",
subtitle_align: str = "right",
) -> None:
"""
Display content in a bordered panel.
Args:
content: Content to display (string, Table, or Markdown)
title: Optional panel title
subtitle: Optional panel subtitle
border_style: Border color/style
title_align: Title alignment ("left", "center", "right")
subtitle_align: Subtitle alignment
"""
panel = Panel(
content,
title=title,
subtitle=subtitle,
border_style=border_style,
title_align=title_align,
subtitle_align=subtitle_align,
)
console.print(panel)
def display_table(
table: Table,
title: Optional[str] = None,
subtitle: Optional[str] = None,
) -> None:
"""
Display a table with optional title panel.
Args:
table: Rich Table to display
title: Optional panel title
subtitle: Optional panel subtitle
"""
if title:
display_panel(table, title=title, subtitle=subtitle)
else:
console.print(table)
def display_markdown(
content: str,
panel: bool = False,
title: Optional[str] = None,
) -> None:
"""
Display markdown-formatted content.
Args:
content: Markdown text to display
panel: Whether to wrap in a panel
title: Optional panel title (if panel=True)
"""
md = Markdown(content)
if panel:
display_panel(md, title=title)
else:
console.print(md)
def print_error(message: str, prefix: str = "Error:") -> None:
"""
Print an error message in red.
Args:
message: Error message to display
prefix: Prefix before the message (default: "Error:")
"""
console.print(f"[bold red]{prefix}[/] {message}")
def print_warning(message: str, prefix: str = "Warning:") -> None:
"""
Print a warning message in yellow.
Args:
message: Warning message to display
prefix: Prefix before the message (default: "Warning:")
"""
console.print(f"[bold yellow]{prefix}[/] {message}")
def print_success(message: str, prefix: str = "") -> None:
"""
Print a success message in green.
Args:
message: Success message to display
prefix: Prefix before the message (default: "")
"""
console.print(f"[bold green]{prefix}[/] {message}")
def print_info(message: str, dim: bool = False) -> None:
"""
Print an informational message in cyan.
Args:
message: Info message to display
dim: Whether to dim the message
"""
if dim:
console.print(f"[dim cyan]{message}[/]")
else:
console.print(f"[bold cyan]{message}[/]")
def print_metrics(
tokens: int,
cost: float,
time_seconds: float,
context_info: str = "",
online: bool = False,
mcp_mode: Optional[str] = None,
tool_loops: int = 0,
session_tokens: int = 0,
session_cost: float = 0.0,
) -> None:
"""
Print formatted metrics for a response.
Args:
tokens: Total tokens used
cost: Cost in USD
time_seconds: Response time
context_info: Context information string
online: Whether online mode is active
mcp_mode: MCP mode ("files", "database", or None)
tool_loops: Number of tool call loops
session_tokens: Total session tokens
session_cost: Total session cost
"""
parts = [
f"📊 Metrics: {tokens} tokens",
f"${cost:.4f}",
f"{time_seconds:.2f}s",
]
if context_info:
parts.append(context_info)
if online:
parts.append("🌐")
if mcp_mode == "files":
parts.append("🔧")
elif mcp_mode == "database":
parts.append("🗄️")
if tool_loops > 0:
parts.append(f"({tool_loops} tool loop(s))")
parts.append(f"Session: {session_tokens} tokens")
parts.append(f"${session_cost:.4f}")
console.print(f"\n[dim blue]{' | '.join(parts)}[/]")
def format_size(size_bytes: int) -> str:
"""
Format a size in bytes to a human-readable string.
Args:
size_bytes: Size in bytes
Returns:
Formatted size string (e.g., "1.5 MB")
"""
for unit in ["B", "KB", "MB", "GB", "TB"]:
if abs(size_bytes) < 1024.0:
return f"{size_bytes:.1f} {unit}"
size_bytes /= 1024.0
return f"{size_bytes:.1f} PB"
def format_tokens(tokens: int) -> str:
"""
Format token count with thousands separators.
Args:
tokens: Number of tokens
Returns:
Formatted token string (e.g., "1,234,567")
"""
return f"{tokens:,}"
def format_cost(cost: float, precision: int = 4) -> str:
"""
Format cost in USD.
Args:
cost: Cost in dollars
precision: Decimal places
Returns:
Formatted cost string (e.g., "$0.0123")
"""
return f"${cost:.{precision}f}"

274
oai/ui/prompts.py Normal file
View File

@@ -0,0 +1,274 @@
"""
Prompt utilities for oAI.
This module provides functions for gathering user input
through confirmations, choices, and text prompts.
"""
from typing import List, Optional, TypeVar
import typer
from oai.ui.console import console
T = TypeVar("T")
def prompt_confirm(
message: str,
default: bool = False,
abort: bool = False,
) -> bool:
"""
Prompt the user for a yes/no confirmation.
Args:
message: The question to ask
default: Default value if user presses Enter
abort: Whether to abort on "no" response
Returns:
True if user confirms, False otherwise
"""
try:
return typer.confirm(message, default=default, abort=abort)
except (EOFError, KeyboardInterrupt):
console.print("\n[yellow]Cancelled[/]")
return False
def prompt_choice(
message: str,
choices: List[str],
default: Optional[str] = None,
) -> Optional[str]:
"""
Prompt the user to select from a list of choices.
Args:
message: The question to ask
choices: List of valid choices
default: Default choice if user presses Enter
Returns:
Selected choice or None if cancelled
"""
# Display choices
console.print(f"\n[bold cyan]{message}[/]")
for i, choice in enumerate(choices, 1):
default_marker = " [default]" if choice == default else ""
console.print(f" {i}. {choice}{default_marker}")
try:
response = input("\nEnter number or value: ").strip()
if not response and default:
return default
# Try as number first
try:
index = int(response) - 1
if 0 <= index < len(choices):
return choices[index]
except ValueError:
pass
# Try as exact match
if response in choices:
return response
# Try case-insensitive match
response_lower = response.lower()
for choice in choices:
if choice.lower() == response_lower:
return choice
console.print(f"[red]Invalid choice: {response}[/]")
return None
except (EOFError, KeyboardInterrupt):
console.print("\n[yellow]Cancelled[/]")
return None
def prompt_input(
message: str,
default: Optional[str] = None,
password: bool = False,
required: bool = False,
) -> Optional[str]:
"""
Prompt the user for text input.
Args:
message: The prompt message
default: Default value if user presses Enter
password: Whether to hide input (for sensitive data)
required: Whether input is required (loops until provided)
Returns:
User input or default, None if cancelled
"""
prompt_text = message
if default:
prompt_text += f" [{default}]"
prompt_text += ": "
try:
while True:
if password:
import getpass
response = getpass.getpass(prompt_text)
else:
response = input(prompt_text).strip()
if not response:
if default:
return default
if required:
console.print("[yellow]Input required[/]")
continue
return None
return response
except (EOFError, KeyboardInterrupt):
console.print("\n[yellow]Cancelled[/]")
return None
def prompt_number(
message: str,
min_value: Optional[int] = None,
max_value: Optional[int] = None,
default: Optional[int] = None,
) -> Optional[int]:
"""
Prompt the user for a numeric input.
Args:
message: The prompt message
min_value: Minimum allowed value
max_value: Maximum allowed value
default: Default value if user presses Enter
Returns:
Integer value or None if cancelled
"""
prompt_text = message
if default is not None:
prompt_text += f" [{default}]"
prompt_text += ": "
try:
while True:
response = input(prompt_text).strip()
if not response:
if default is not None:
return default
return None
try:
value = int(response)
except ValueError:
console.print("[red]Please enter a valid number[/]")
continue
if min_value is not None and value < min_value:
console.print(f"[red]Value must be at least {min_value}[/]")
continue
if max_value is not None and value > max_value:
console.print(f"[red]Value must be at most {max_value}[/]")
continue
return value
except (EOFError, KeyboardInterrupt):
console.print("\n[yellow]Cancelled[/]")
return None
def prompt_selection(
items: List[T],
message: str = "Select an item",
display_func: Optional[callable] = None,
allow_cancel: bool = True,
) -> Optional[T]:
"""
Prompt the user to select an item from a list.
Args:
items: List of items to choose from
message: The selection prompt
display_func: Function to convert item to display string
allow_cancel: Whether to allow cancellation
Returns:
Selected item or None if cancelled
"""
if not items:
console.print("[yellow]No items to select[/]")
return None
display = display_func or str
console.print(f"\n[bold cyan]{message}[/]")
for i, item in enumerate(items, 1):
console.print(f" {i}. {display(item)}")
if allow_cancel:
console.print(f" 0. Cancel")
try:
while True:
response = input("\nEnter number: ").strip()
try:
index = int(response)
except ValueError:
console.print("[red]Please enter a valid number[/]")
continue
if allow_cancel and index == 0:
return None
if 1 <= index <= len(items):
return items[index - 1]
console.print(f"[red]Please enter a number between 1 and {len(items)}[/]")
except (EOFError, KeyboardInterrupt):
console.print("\n[yellow]Cancelled[/]")
return None
def prompt_copy_response(response: str) -> bool:
"""
Prompt user to copy a response to clipboard.
Args:
response: The response text
Returns:
True if copied, False otherwise
"""
try:
copy_choice = input("💾 Type 'c' to copy response, or press Enter to continue: ").strip().lower()
if copy_choice == "c":
try:
import pyperclip
pyperclip.copy(response)
console.print("[bold green]✅ Response copied to clipboard![/]")
return True
except ImportError:
console.print("[yellow]pyperclip not installed - cannot copy to clipboard[/]")
except Exception as e:
console.print(f"[red]Failed to copy: {e}[/]")
except (EOFError, KeyboardInterrupt):
pass
return False

373
oai/ui/tables.py Normal file
View File

@@ -0,0 +1,373 @@
"""
Table utilities for oAI.
This module provides functions for creating and displaying
formatted tables with pagination support.
"""
import os
import sys
from typing import Any, Dict, List, Optional
from rich.panel import Panel
from rich.table import Table
from oai.ui.console import clear_screen, console
def create_model_table(
models: List[Dict[str, Any]],
show_capabilities: bool = True,
) -> Table:
"""
Create a table displaying available AI models.
Args:
models: List of model dictionaries
show_capabilities: Whether to show capability columns
Returns:
Rich Table with model information
"""
if show_capabilities:
table = Table(
"No.",
"Model ID",
"Context",
"Image",
"Online",
"Tools",
show_header=True,
header_style="bold magenta",
)
else:
table = Table(
"No.",
"Model ID",
"Context",
show_header=True,
header_style="bold magenta",
)
for i, model in enumerate(models, 1):
model_id = model.get("id", "Unknown")
context = model.get("context_length", 0)
context_str = f"{context:,}" if context else "-"
if show_capabilities:
# Get modalities and parameters
architecture = model.get("architecture", {})
input_modalities = architecture.get("input_modalities", [])
supported_params = model.get("supported_parameters", [])
has_image = "" if "image" in input_modalities else "-"
has_online = "" if "tools" in supported_params else "-"
has_tools = "" if "tools" in supported_params or "functions" in supported_params else "-"
table.add_row(
str(i),
model_id,
context_str,
has_image,
has_online,
has_tools,
)
else:
table.add_row(str(i), model_id, context_str)
return table
def create_stats_table(stats: Dict[str, Any]) -> Table:
"""
Create a table displaying session statistics.
Args:
stats: Dictionary with statistics data
Returns:
Rich Table with stats
"""
table = Table(
"Metric",
"Value",
show_header=True,
header_style="bold magenta",
)
# Token stats
if "input_tokens" in stats:
table.add_row("Input Tokens", f"{stats['input_tokens']:,}")
if "output_tokens" in stats:
table.add_row("Output Tokens", f"{stats['output_tokens']:,}")
if "total_tokens" in stats:
table.add_row("Total Tokens", f"{stats['total_tokens']:,}")
# Cost stats
if "total_cost" in stats:
table.add_row("Total Cost", f"${stats['total_cost']:.4f}")
if "avg_cost" in stats:
table.add_row("Avg Cost/Message", f"${stats['avg_cost']:.4f}")
# Message stats
if "message_count" in stats:
table.add_row("Messages", str(stats["message_count"]))
# Credits
if "credits_left" in stats:
table.add_row("Credits Left", stats["credits_left"])
return table
def create_help_table(commands: Dict[str, Dict[str, str]]) -> Table:
"""
Create a help table for commands.
Args:
commands: Dictionary of command info
Returns:
Rich Table with command help
"""
table = Table(
"Command",
"Description",
"Example",
show_header=True,
header_style="bold magenta",
show_lines=False,
)
for cmd, info in commands.items():
description = info.get("description", "")
example = info.get("example", "")
table.add_row(cmd, description, example)
return table
def create_folder_table(
folders: List[Dict[str, Any]],
gitignore_info: str = "",
) -> Table:
"""
Create a table for MCP folder listing.
Args:
folders: List of folder dictionaries
gitignore_info: Optional gitignore status info
Returns:
Rich Table with folder information
"""
table = Table(
"No.",
"Path",
"Files",
"Size",
show_header=True,
header_style="bold magenta",
)
for folder in folders:
number = str(folder.get("number", ""))
path = folder.get("path", "")
if folder.get("exists", True):
files = f"📁 {folder.get('file_count', 0)}"
size = f"{folder.get('size_mb', 0):.1f} MB"
else:
files = "[red]Not found[/red]"
size = "-"
table.add_row(number, path, files, size)
return table
def create_database_table(databases: List[Dict[str, Any]]) -> Table:
"""
Create a table for MCP database listing.
Args:
databases: List of database dictionaries
Returns:
Rich Table with database information
"""
table = Table(
"No.",
"Name",
"Tables",
"Size",
"Status",
show_header=True,
header_style="bold magenta",
)
for db in databases:
number = str(db.get("number", ""))
name = db.get("name", "")
table_count = f"{db.get('table_count', 0)} tables"
size = f"{db.get('size_mb', 0):.1f} MB"
if db.get("warning"):
status = f"[red]{db['warning']}[/red]"
else:
status = "[green]✓[/green]"
table.add_row(number, name, table_count, size, status)
return table
def display_paginated_table(
table: Table,
title: str,
terminal_height: Optional[int] = None,
) -> None:
"""
Display a table with pagination for large datasets.
Allows navigating through pages with keyboard input.
Press SPACE for next page, any other key to exit.
Args:
table: Rich Table to display
title: Title for the table
terminal_height: Override terminal height (auto-detected if None)
"""
# Get terminal dimensions
try:
term_height = terminal_height or os.get_terminal_size().lines - 8
except OSError:
term_height = 20
# Render table to segments
from rich.segment import Segment
segments = list(console.render(table))
# Group segments into lines
current_line_segments: List[Segment] = []
all_lines: List[List[Segment]] = []
for segment in segments:
if segment.text == "\n":
all_lines.append(current_line_segments)
current_line_segments = []
else:
current_line_segments.append(segment)
if current_line_segments:
all_lines.append(current_line_segments)
total_lines = len(all_lines)
# If table fits in one screen, just display it
if total_lines <= term_height:
console.print(Panel(table, title=title, title_align="left"))
return
# Extract header and footer lines
header_lines: List[List[Segment]] = []
data_lines: List[List[Segment]] = []
footer_line: List[Segment] = []
# Find header end (line after the header text with border)
header_end_index = 0
found_header_text = False
for i, line_segments in enumerate(all_lines):
has_header_style = any(
seg.style and ("bold" in str(seg.style) or "magenta" in str(seg.style))
for seg in line_segments
)
if has_header_style:
found_header_text = True
if found_header_text and i > 0:
line_text = "".join(seg.text for seg in line_segments)
if any(char in line_text for char in ["", "", "", "", "", ""]):
header_end_index = i
break
# Extract footer (bottom border)
if all_lines:
last_line_text = "".join(seg.text for seg in all_lines[-1])
if any(char in last_line_text for char in ["", "", "", "", "", ""]):
footer_line = all_lines[-1]
all_lines = all_lines[:-1]
# Split into header and data
if header_end_index > 0:
header_lines = all_lines[: header_end_index + 1]
data_lines = all_lines[header_end_index + 1 :]
else:
header_lines = all_lines[: min(3, len(all_lines))]
data_lines = all_lines[min(3, len(all_lines)) :]
lines_per_page = term_height - len(header_lines)
current_line = 0
page_number = 1
# Paginate
while current_line < len(data_lines):
clear_screen()
console.print(f"[bold cyan]{title} (Page {page_number})[/]")
# Print header
for line_segments in header_lines:
for segment in line_segments:
console.print(segment.text, style=segment.style, end="")
console.print()
# Print data rows for this page
end_line = min(current_line + lines_per_page, len(data_lines))
for line_segments in data_lines[current_line:end_line]:
for segment in line_segments:
console.print(segment.text, style=segment.style, end="")
console.print()
# Print footer
if footer_line:
for segment in footer_line:
console.print(segment.text, style=segment.style, end="")
console.print()
current_line = end_line
page_number += 1
# Prompt for next page
if current_line < len(data_lines):
console.print(
f"\n[dim yellow]--- Press SPACE for next page, "
f"or any other key to finish (Page {page_number - 1}, "
f"showing {end_line}/{len(data_lines)} data rows) ---[/dim yellow]"
)
try:
import termios
import tty
fd = sys.stdin.fileno()
old_settings = termios.tcgetattr(fd)
try:
tty.setraw(fd)
char = sys.stdin.read(1)
if char != " ":
break
finally:
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
except (ImportError, OSError, AttributeError):
# Fallback for non-Unix systems
try:
user_input = input()
if user_input.strip():
break
except (EOFError, KeyboardInterrupt):
break

20
oai/utils/__init__.py Normal file
View File

@@ -0,0 +1,20 @@
"""
Utility modules for oAI.
This package provides common utilities used throughout the application
including logging, file handling, and export functionality.
"""
from oai.utils.logging import setup_logging, get_logger
from oai.utils.files import read_file_safe, is_binary_file
from oai.utils.export import export_as_markdown, export_as_json, export_as_html
__all__ = [
"setup_logging",
"get_logger",
"read_file_safe",
"is_binary_file",
"export_as_markdown",
"export_as_json",
"export_as_html",
]

248
oai/utils/export.py Normal file
View File

@@ -0,0 +1,248 @@
"""
Export utilities for oAI.
This module provides functions for exporting conversation history
in various formats including Markdown, JSON, and HTML.
"""
import json
import datetime
from typing import List, Dict
from html import escape as html_escape
from oai.constants import APP_VERSION, APP_URL
def export_as_markdown(
session_history: List[Dict[str, str]],
session_system_prompt: str = ""
) -> str:
"""
Export conversation history as Markdown.
Args:
session_history: List of message dictionaries with 'prompt' and 'response'
session_system_prompt: Optional system prompt to include
Returns:
Markdown formatted string
"""
lines = ["# Conversation Export", ""]
if session_system_prompt:
lines.extend([f"**System Prompt:** {session_system_prompt}", ""])
lines.append(f"**Export Date:** {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
lines.append(f"**Messages:** {len(session_history)}")
lines.append("")
lines.append("---")
lines.append("")
for i, entry in enumerate(session_history, 1):
lines.append(f"## Message {i}")
lines.append("")
lines.append("**User:**")
lines.append("")
lines.append(entry.get("prompt", ""))
lines.append("")
lines.append("**Assistant:**")
lines.append("")
lines.append(entry.get("response", ""))
lines.append("")
lines.append("---")
lines.append("")
lines.append(f"*Exported from oAI v{APP_VERSION} - {APP_URL}*")
return "\n".join(lines)
def export_as_json(
session_history: List[Dict[str, str]],
session_system_prompt: str = ""
) -> str:
"""
Export conversation history as JSON.
Args:
session_history: List of message dictionaries
session_system_prompt: Optional system prompt to include
Returns:
JSON formatted string
"""
export_data = {
"export_date": datetime.datetime.now().isoformat(),
"app_version": APP_VERSION,
"system_prompt": session_system_prompt,
"message_count": len(session_history),
"messages": [
{
"index": i + 1,
"prompt": entry.get("prompt", ""),
"response": entry.get("response", ""),
"prompt_tokens": entry.get("prompt_tokens", 0),
"completion_tokens": entry.get("completion_tokens", 0),
"cost": entry.get("msg_cost", 0.0),
}
for i, entry in enumerate(session_history)
],
"totals": {
"prompt_tokens": sum(e.get("prompt_tokens", 0) for e in session_history),
"completion_tokens": sum(e.get("completion_tokens", 0) for e in session_history),
"total_cost": sum(e.get("msg_cost", 0.0) for e in session_history),
}
}
return json.dumps(export_data, indent=2, ensure_ascii=False)
def export_as_html(
session_history: List[Dict[str, str]],
session_system_prompt: str = ""
) -> str:
"""
Export conversation history as styled HTML.
Args:
session_history: List of message dictionaries
session_system_prompt: Optional system prompt to include
Returns:
HTML formatted string with embedded CSS
"""
html_parts = [
"<!DOCTYPE html>",
"<html>",
"<head>",
" <meta charset='UTF-8'>",
" <meta name='viewport' content='width=device-width, initial-scale=1.0'>",
" <title>Conversation Export - oAI</title>",
" <style>",
" * { box-sizing: border-box; }",
" body {",
" font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;",
" max-width: 900px;",
" margin: 40px auto;",
" padding: 20px;",
" background: #f5f5f5;",
" color: #333;",
" }",
" .header {",
" background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);",
" color: white;",
" padding: 30px;",
" border-radius: 10px;",
" margin-bottom: 30px;",
" box-shadow: 0 4px 6px rgba(0,0,0,0.1);",
" }",
" .header h1 {",
" margin: 0 0 10px 0;",
" font-size: 2em;",
" }",
" .export-info {",
" opacity: 0.9;",
" font-size: 0.95em;",
" margin: 5px 0;",
" }",
" .system-prompt {",
" background: #fff3cd;",
" padding: 20px;",
" border-radius: 8px;",
" margin-bottom: 25px;",
" border-left: 5px solid #ffc107;",
" box-shadow: 0 2px 4px rgba(0,0,0,0.05);",
" }",
" .system-prompt strong {",
" color: #856404;",
" display: block;",
" margin-bottom: 10px;",
" font-size: 1.1em;",
" }",
" .message-container { margin-bottom: 20px; }",
" .message {",
" background: white;",
" padding: 20px;",
" border-radius: 8px;",
" box-shadow: 0 2px 4px rgba(0,0,0,0.08);",
" margin-bottom: 12px;",
" }",
" .user-message { border-left: 5px solid #10b981; }",
" .assistant-message { border-left: 5px solid #3b82f6; }",
" .role {",
" font-weight: bold;",
" margin-bottom: 12px;",
" font-size: 1.05em;",
" text-transform: uppercase;",
" letter-spacing: 0.5px;",
" }",
" .user-role { color: #10b981; }",
" .assistant-role { color: #3b82f6; }",
" .content {",
" line-height: 1.8;",
" white-space: pre-wrap;",
" color: #333;",
" }",
" .message-number {",
" color: #6b7280;",
" font-size: 0.85em;",
" margin-bottom: 15px;",
" font-weight: 600;",
" }",
" .footer {",
" text-align: center;",
" margin-top: 40px;",
" padding: 20px;",
" color: #6b7280;",
" font-size: 0.9em;",
" }",
" .footer a { color: #667eea; text-decoration: none; }",
" .footer a:hover { text-decoration: underline; }",
" @media print {",
" body { background: white; }",
" .message { break-inside: avoid; }",
" }",
" </style>",
"</head>",
"<body>",
" <div class='header'>",
" <h1>Conversation Export</h1>",
f" <div class='export-info'>Exported: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</div>",
f" <div class='export-info'>Total Messages: {len(session_history)}</div>",
" </div>",
]
if session_system_prompt:
html_parts.extend([
" <div class='system-prompt'>",
" <strong>System Prompt</strong>",
f" <div>{html_escape(session_system_prompt)}</div>",
" </div>",
])
for i, entry in enumerate(session_history, 1):
prompt = html_escape(entry.get("prompt", ""))
response = html_escape(entry.get("response", ""))
html_parts.extend([
" <div class='message-container'>",
f" <div class='message-number'>Message {i} of {len(session_history)}</div>",
" <div class='message user-message'>",
" <div class='role user-role'>User</div>",
f" <div class='content'>{prompt}</div>",
" </div>",
" <div class='message assistant-message'>",
" <div class='role assistant-role'>Assistant</div>",
f" <div class='content'>{response}</div>",
" </div>",
" </div>",
])
html_parts.extend([
" <div class='footer'>",
f" <p>Generated by oAI v{APP_VERSION} &bull; <a href='{APP_URL}'>{APP_URL}</a></p>",
" </div>",
"</body>",
"</html>",
])
return "\n".join(html_parts)

323
oai/utils/files.py Normal file
View File

@@ -0,0 +1,323 @@
"""
File handling utilities for oAI.
This module provides safe file reading, type detection, and other
file-related operations used throughout the application.
"""
import os
import mimetypes
import base64
from pathlib import Path
from typing import Optional, Dict, Any, Tuple
from oai.constants import (
MAX_FILE_SIZE,
CONTENT_TRUNCATION_THRESHOLD,
SUPPORTED_CODE_EXTENSIONS,
ALLOWED_FILE_EXTENSIONS,
)
from oai.utils.logging import get_logger
def is_binary_file(file_path: Path) -> bool:
"""
Check if a file appears to be binary.
Args:
file_path: Path to the file to check
Returns:
True if the file appears to be binary, False otherwise
"""
try:
with open(file_path, "rb") as f:
# Read first 8KB to check for binary content
chunk = f.read(8192)
# Check for null bytes (common in binary files)
if b"\x00" in chunk:
return True
# Try to decode as UTF-8
try:
chunk.decode("utf-8")
return False
except UnicodeDecodeError:
return True
except Exception:
return True
def get_file_type(file_path: Path) -> Tuple[Optional[str], str]:
"""
Determine the MIME type and category of a file.
Args:
file_path: Path to the file
Returns:
Tuple of (mime_type, category) where category is one of:
'image', 'pdf', 'code', 'text', 'binary', 'unknown'
"""
mime_type, _ = mimetypes.guess_type(str(file_path))
ext = file_path.suffix.lower()
if mime_type and mime_type.startswith("image/"):
return mime_type, "image"
elif mime_type == "application/pdf" or ext == ".pdf":
return mime_type or "application/pdf", "pdf"
elif ext in SUPPORTED_CODE_EXTENSIONS:
return mime_type or "text/plain", "code"
elif mime_type and mime_type.startswith("text/"):
return mime_type, "text"
elif is_binary_file(file_path):
return mime_type, "binary"
else:
return mime_type, "unknown"
def read_file_safe(
file_path: Path,
max_size: int = MAX_FILE_SIZE,
truncate_threshold: int = CONTENT_TRUNCATION_THRESHOLD
) -> Dict[str, Any]:
"""
Safely read a file with size limits and truncation support.
Args:
file_path: Path to the file to read
max_size: Maximum file size to read (bytes)
truncate_threshold: Threshold for truncating large files
Returns:
Dictionary containing:
- content: File content (text or base64)
- size: File size in bytes
- truncated: Whether content was truncated
- encoding: 'text', 'base64', or None on error
- error: Error message if reading failed
"""
logger = get_logger()
try:
path = Path(file_path).resolve()
if not path.exists():
return {
"content": None,
"size": 0,
"truncated": False,
"encoding": None,
"error": f"File not found: {path}"
}
if not path.is_file():
return {
"content": None,
"size": 0,
"truncated": False,
"encoding": None,
"error": f"Not a file: {path}"
}
file_size = path.stat().st_size
if file_size > max_size:
return {
"content": None,
"size": file_size,
"truncated": False,
"encoding": None,
"error": f"File too large: {file_size / (1024*1024):.1f}MB (max: {max_size / (1024*1024):.0f}MB)"
}
# Try to read as text first
try:
content = path.read_text(encoding="utf-8")
# Check if truncation is needed
if file_size > truncate_threshold:
lines = content.split("\n")
total_lines = len(lines)
# Keep first 500 lines and last 100 lines
head_lines = 500
tail_lines = 100
if total_lines > (head_lines + tail_lines):
truncated_content = (
"\n".join(lines[:head_lines]) +
f"\n\n... [TRUNCATED: {total_lines - head_lines - tail_lines} lines omitted] ...\n\n" +
"\n".join(lines[-tail_lines:])
)
logger.info(f"Read file (truncated): {path} ({file_size} bytes, {total_lines} lines)")
return {
"content": truncated_content,
"size": file_size,
"truncated": True,
"total_lines": total_lines,
"lines_shown": head_lines + tail_lines,
"encoding": "text",
"error": None
}
logger.info(f"Read file: {path} ({file_size} bytes)")
return {
"content": content,
"size": file_size,
"truncated": False,
"encoding": "text",
"error": None
}
except UnicodeDecodeError:
# File is binary, return base64 encoded
with open(path, "rb") as f:
binary_data = f.read()
b64_content = base64.b64encode(binary_data).decode("utf-8")
logger.info(f"Read binary file: {path} ({file_size} bytes)")
return {
"content": b64_content,
"size": file_size,
"truncated": False,
"encoding": "base64",
"error": None
}
except PermissionError as e:
return {
"content": None,
"size": 0,
"truncated": False,
"encoding": None,
"error": f"Permission denied: {e}"
}
except Exception as e:
logger.error(f"Error reading file {file_path}: {e}")
return {
"content": None,
"size": 0,
"truncated": False,
"encoding": None,
"error": str(e)
}
def get_file_extension(file_path: Path) -> str:
"""
Get the lowercase file extension.
Args:
file_path: Path to the file
Returns:
Lowercase extension including the dot (e.g., '.py')
"""
return file_path.suffix.lower()
def is_allowed_extension(file_path: Path) -> bool:
"""
Check if a file has an allowed extension for attachment.
Args:
file_path: Path to the file
Returns:
True if the extension is allowed, False otherwise
"""
return get_file_extension(file_path) in ALLOWED_FILE_EXTENSIONS
def format_file_size(size_bytes: int) -> str:
"""
Format a file size in human-readable format.
Args:
size_bytes: Size in bytes
Returns:
Formatted string (e.g., '1.5 MB', '512 KB')
"""
for unit in ["B", "KB", "MB", "GB", "TB"]:
if abs(size_bytes) < 1024:
return f"{size_bytes:.1f} {unit}"
size_bytes /= 1024
return f"{size_bytes:.1f} PB"
def prepare_file_attachment(
file_path: Path,
model_capabilities: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
"""
Prepare a file for attachment to an API request.
Args:
file_path: Path to the file
model_capabilities: Model capability information
Returns:
Content block dictionary for the API, or None if unsupported
"""
logger = get_logger()
path = Path(file_path).resolve()
if not path.exists():
logger.warning(f"File not found: {path}")
return None
mime_type, category = get_file_type(path)
file_size = path.stat().st_size
if file_size > MAX_FILE_SIZE:
logger.warning(f"File too large: {path} ({format_file_size(file_size)})")
return None
try:
with open(path, "rb") as f:
file_data = f.read()
if category == "image":
# Check if model supports images
input_modalities = model_capabilities.get("architecture", {}).get("input_modalities", [])
if "image" not in input_modalities:
logger.warning(f"Model does not support images")
return None
b64_data = base64.b64encode(file_data).decode("utf-8")
return {
"type": "image_url",
"image_url": {"url": f"data:{mime_type};base64,{b64_data}"}
}
elif category == "pdf":
# Check if model supports PDFs
input_modalities = model_capabilities.get("architecture", {}).get("input_modalities", [])
supports_pdf = any(mod in input_modalities for mod in ["document", "pdf", "file"])
if not supports_pdf:
logger.warning(f"Model does not support PDFs")
return None
b64_data = base64.b64encode(file_data).decode("utf-8")
return {
"type": "image_url",
"image_url": {"url": f"data:application/pdf;base64,{b64_data}"}
}
elif category in ("code", "text"):
text_content = file_data.decode("utf-8")
return {
"type": "text",
"text": f"File: {path.name}\n\n{text_content}"
}
else:
logger.warning(f"Unsupported file type: {category} ({mime_type})")
return None
except UnicodeDecodeError:
logger.error(f"Cannot decode file as UTF-8: {path}")
return None
except Exception as e:
logger.error(f"Error preparing file attachment {path}: {e}")
return None

297
oai/utils/logging.py Normal file
View File

@@ -0,0 +1,297 @@
"""
Logging configuration for oAI.
This module provides centralized logging setup with Rich formatting,
file rotation, and configurable log levels.
"""
import io
import os
import glob
import logging
import datetime
import shutil
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Optional
from rich.console import Console
from rich.logging import RichHandler
from oai.constants import (
LOG_FILE,
CONFIG_DIR,
DEFAULT_LOG_MAX_SIZE_MB,
DEFAULT_LOG_BACKUP_COUNT,
DEFAULT_LOG_LEVEL,
VALID_LOG_LEVELS,
)
class RotatingRichHandler(RotatingFileHandler):
"""
Custom log handler combining file rotation with Rich formatting.
This handler writes Rich-formatted log output to a rotating file,
providing colored and formatted logs even in file output while
managing file size and backups automatically.
"""
def __init__(self, *args, **kwargs):
"""Initialize the handler with Rich console for formatting."""
super().__init__(*args, **kwargs)
# Create an internal console for Rich formatting
self.rich_console = Console(
file=io.StringIO(),
width=120,
force_terminal=False
)
self.rich_handler = RichHandler(
console=self.rich_console,
show_time=True,
show_path=True,
rich_tracebacks=True,
tracebacks_suppress=["requests", "openrouter", "urllib3", "httpx", "openai"]
)
def emit(self, record: logging.LogRecord) -> None:
"""
Emit a log record with Rich formatting.
Args:
record: The log record to emit
"""
try:
# Format with Rich
self.rich_handler.emit(record)
output = self.rich_console.file.getvalue()
self.rich_console.file.seek(0)
self.rich_console.file.truncate(0)
if output:
self.stream.write(output)
self.flush()
except Exception:
self.handleError(record)
class LoggingManager:
"""
Manages application logging configuration.
Provides methods to setup, configure, and manage logging with
support for runtime reconfiguration and level changes.
"""
def __init__(self):
"""Initialize the logging manager."""
self.handler: Optional[RotatingRichHandler] = None
self.app_logger: Optional[logging.Logger] = None
self.max_size_mb: int = DEFAULT_LOG_MAX_SIZE_MB
self.backup_count: int = DEFAULT_LOG_BACKUP_COUNT
self.log_level: str = DEFAULT_LOG_LEVEL
def setup(
self,
max_size_mb: Optional[int] = None,
backup_count: Optional[int] = None,
log_level: Optional[str] = None
) -> logging.Logger:
"""
Setup or reconfigure logging.
Args:
max_size_mb: Maximum log file size in MB
backup_count: Number of backup files to keep
log_level: Logging level string
Returns:
The configured application logger
"""
# Update configuration if provided
if max_size_mb is not None:
self.max_size_mb = max_size_mb
if backup_count is not None:
self.backup_count = backup_count
if log_level is not None:
self.log_level = log_level
# Ensure config directory exists
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
# Get root logger
root_logger = logging.getLogger()
# Remove existing handler if present
if self.handler is not None:
root_logger.removeHandler(self.handler)
try:
self.handler.close()
except Exception:
pass
# Check if log needs manual rotation
self._check_rotation()
# Create new handler
max_bytes = self.max_size_mb * 1024 * 1024
self.handler = RotatingRichHandler(
filename=str(LOG_FILE),
maxBytes=max_bytes,
backupCount=self.backup_count,
encoding="utf-8"
)
self.handler.setLevel(logging.NOTSET)
root_logger.setLevel(logging.WARNING)
root_logger.addHandler(self.handler)
# Suppress noisy third-party loggers
for logger_name in [
"asyncio", "urllib3", "requests", "httpx",
"httpcore", "openai", "openrouter"
]:
logging.getLogger(logger_name).setLevel(logging.WARNING)
# Configure application logger
self.app_logger = logging.getLogger("oai_app")
level = VALID_LOG_LEVELS.get(self.log_level.lower(), logging.INFO)
self.app_logger.setLevel(level)
self.app_logger.propagate = True
return self.app_logger
def _check_rotation(self) -> None:
"""Check if log file needs rotation and perform if necessary."""
if not LOG_FILE.exists():
return
current_size = LOG_FILE.stat().st_size
max_bytes = self.max_size_mb * 1024 * 1024
if current_size >= max_bytes:
# Perform manual rotation
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
backup_file = f"{LOG_FILE}.{timestamp}"
try:
shutil.move(str(LOG_FILE), backup_file)
except Exception:
pass
# Clean old backups
self._cleanup_old_backups()
def _cleanup_old_backups(self) -> None:
"""Remove old backup files exceeding the backup count."""
log_dir = LOG_FILE.parent
backup_pattern = f"{LOG_FILE.name}.*"
backups = sorted(glob.glob(str(log_dir / backup_pattern)))
while len(backups) > self.backup_count:
oldest = backups.pop(0)
try:
os.remove(oldest)
except Exception:
pass
def set_level(self, level: str) -> bool:
"""
Set the application log level.
Args:
level: Log level string (debug/info/warning/error/critical)
Returns:
True if level was set successfully, False otherwise
"""
level_lower = level.lower()
if level_lower not in VALID_LOG_LEVELS:
return False
self.log_level = level_lower
if self.app_logger:
self.app_logger.setLevel(VALID_LOG_LEVELS[level_lower])
return True
def get_logger(self) -> logging.Logger:
"""
Get the application logger, initializing if necessary.
Returns:
The application logger
"""
if self.app_logger is None:
self.setup()
return self.app_logger
# Global logging manager instance
_logging_manager = LoggingManager()
def setup_logging(
max_size_mb: Optional[int] = None,
backup_count: Optional[int] = None,
log_level: Optional[str] = None
) -> logging.Logger:
"""
Setup application logging.
This is the main entry point for configuring logging. Call this
early in application startup.
Args:
max_size_mb: Maximum log file size in MB
backup_count: Number of backup files to keep
log_level: Logging level string
Returns:
The configured application logger
"""
return _logging_manager.setup(max_size_mb, backup_count, log_level)
def get_logger() -> logging.Logger:
"""
Get the application logger.
Returns:
The application logger instance
"""
return _logging_manager.get_logger()
def set_log_level(level: str) -> bool:
"""
Set the application log level.
Args:
level: Log level string
Returns:
True if successful, False otherwise
"""
return _logging_manager.set_level(level)
def reload_logging(
max_size_mb: Optional[int] = None,
backup_count: Optional[int] = None,
log_level: Optional[str] = None
) -> logging.Logger:
"""
Reload logging configuration.
Useful when settings change at runtime.
Args:
max_size_mb: New maximum log file size
backup_count: New backup count
log_level: New log level
Returns:
The reconfigured logger
"""
return _logging_manager.setup(max_size_mb, backup_count, log_level)

134
pyproject.toml Normal file
View File

@@ -0,0 +1,134 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "oai"
version = "2.1.0"
description = "OpenRouter AI Chat Client - A feature-rich terminal-based chat application"
readme = "README.md"
license = {text = "MIT"}
authors = [
{name = "Rune", email = "rune@example.com"}
]
maintainers = [
{name = "Rune", email = "rune@example.com"}
]
keywords = [
"ai",
"chat",
"openrouter",
"cli",
"terminal",
"mcp",
"llm",
]
classifiers = [
"Development Status :: 4 - Beta",
"Environment :: Console",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Utilities",
]
requires-python = ">=3.10"
dependencies = [
"anyio>=4.0.0",
"click>=8.0.0",
"httpx>=0.24.0",
"markdown-it-py>=3.0.0",
"openrouter>=0.0.19",
"packaging>=21.0",
"prompt-toolkit>=3.0.0",
"pyperclip>=1.8.0",
"requests>=2.28.0",
"rich>=13.0.0",
"typer>=0.9.0",
"mcp>=1.0.0",
]
[project.optional-dependencies]
dev = [
"pytest>=7.0.0",
"pytest-asyncio>=0.21.0",
"pytest-cov>=4.0.0",
"black>=23.0.0",
"isort>=5.12.0",
"mypy>=1.0.0",
"ruff>=0.1.0",
]
[project.urls]
Homepage = "https://iurl.no/oai"
Repository = "https://gitlab.pm/rune/oai"
Documentation = "https://iurl.no/oai"
"Bug Tracker" = "https://gitlab.pm/rune/oai/issues"
[project.scripts]
oai = "oai.cli:main"
[tool.setuptools]
packages = ["oai", "oai.commands", "oai.config", "oai.core", "oai.mcp", "oai.providers", "oai.ui", "oai.utils"]
[tool.setuptools.package-data]
oai = ["py.typed"]
[tool.black]
line-length = 100
target-version = ["py310", "py311", "py312"]
include = '\.pyi?$'
exclude = '''
/(
\.git
| \.mypy_cache
| \.pytest_cache
| \.venv
| build
| dist
)/
'''
[tool.isort]
profile = "black"
line_length = 100
skip_gitignore = true
[tool.mypy]
python_version = "3.10"
warn_return_any = true
warn_unused_configs = true
ignore_missing_imports = true
exclude = [
"build",
"dist",
".venv",
]
[tool.ruff]
line-length = 100
target-version = "py310"
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # Pyflakes
"I", # isort
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"UP", # pyupgrade
]
ignore = [
"E501", # line too long (handled by black)
"B008", # do not perform function calls in argument defaults
"C901", # too complex
]
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
asyncio_mode = "auto"
addopts = "-v --tb=short"

View File

@@ -1,38 +1,26 @@
anyio==4.11.0
beautifulsoup4==4.14.2
charset-normalizer==3.4.4
click==8.3.1
docopt==0.6.2
h11==0.16.0
httpcore==1.0.9
httpx==0.28.1
idna==3.11
latex2mathml==3.78.1
loguru==0.7.3
markdown-it-py==4.0.0
markdown2==2.5.4
mdurl==0.1.2
natsort==8.4.0
openrouter==0.0.19
packaging==25.0
pipreqs==0.4.13
prompt-toolkit==3.0.52
Pygments==2.19.2
pyperclip==1.11.0
python-dateutil==2.9.0.post0
python-magic==0.4.27
PyYAML==6.0.3
requests==2.32.5
rich==14.2.0
shellingham==1.5.4
six==1.17.0
sniffio==1.3.1
soupsieve==2.8
svgwrite==1.4.3
tqdm==4.67.1
typer==0.20.0
typing-extensions==4.15.0
urllib3==2.5.0
wavedrom==2.0.3.post3
wcwidth==0.2.14
yarg==0.1.10
# oai.py v2.1.0-beta - Core Dependencies
anyio>=4.11.0
charset-normalizer>=3.4.4
click>=8.3.1
h11>=0.16.0
httpcore>=1.0.9
httpx>=0.28.1
idna>=3.11
markdown-it-py>=4.0.0
mdurl>=0.1.2
openrouter>=0.0.19
packaging>=25.0
prompt-toolkit>=3.0.52
Pygments>=2.19.2
pyperclip>=1.11.0
requests>=2.32.5
rich>=14.2.0
shellingham>=1.5.4
sniffio>=1.3.1
typer>=0.20.0
typing-extensions>=4.15.0
urllib3>=2.5.0
wcwidth>=0.2.14
# MCP (Model Context Protocol)
mcp>=1.25.0