Files
oai/oai/cli.py
2026-02-03 08:57:16 +01:00

720 lines
24 KiB
Python

"""
Main CLI entry point for oAI.
This module provides the command-line interface for the oAI chat application.
"""
import os
import sys
from pathlib import Path
from typing import Optional
import typer
from prompt_toolkit import PromptSession
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.history import FileHistory
from rich.markdown import Markdown
from rich.panel import Panel
from oai import __version__
from oai.commands import register_all_commands, registry
from oai.commands.registry import CommandContext, CommandStatus
from oai.config.database import Database
from oai.config.settings import Settings
from oai.constants import (
APP_NAME,
APP_URL,
APP_VERSION,
CONFIG_DIR,
HISTORY_FILE,
VALID_COMMANDS,
)
from oai.core.client import AIClient
from oai.core.session import ChatSession
from oai.mcp.manager import MCPManager
from oai.providers.base import UsageStats
from oai.providers.openrouter import OpenRouterProvider
from oai.ui.console import (
clear_screen,
console,
display_panel,
print_error,
print_info,
print_success,
print_warning,
)
from oai.ui.tables import create_model_table, display_paginated_table
from oai.utils.logging import LoggingManager, get_logger
# Root Typer application. Sub-commands attach below via @app.command() and
# the global --version option via @app.callback().
app = typer.Typer(
    name="oai",
    add_completion=False,
    help="oAI - OpenRouter AI Chat Client\n\n" f"Version: {APP_VERSION}",
    epilog=f"For more information, visit: {APP_URL}",
)
@app.callback(invoke_without_command=True)
def main_callback(
    ctx: typer.Context,
    version_flag: bool = typer.Option(
        False,
        "--version",
        "-v",
        help="Show version information",
        is_flag=True,
    ),
) -> None:
    """Main callback to handle global options."""
    # Global options for the whole program:
    #   --version / -v : print the version banner (with update check) and stop.
    # Additionally, a raw "--help"/"-h" anywhere on the command line prints the
    # banner before the help text. ``ctx`` is accepted but not used; the
    # selected sub-command (if any) runs after this callback returns.
    if version_flag:
        console.print(check_for_updates(APP_VERSION))
        raise typer.Exit()

    # NOTE(review): Click handles --help eagerly while parsing, so for a bare
    # ``oai --help`` this callback may never run — confirm whether the banner
    # is expected there.
    wants_help = any(flag in sys.argv for flag in ("--help", "-h"))
    if wants_help:
        console.print(f"\n{check_for_updates(APP_VERSION)}\n")
def check_for_updates(current_version: str) -> str:
    """Return a Rich-markup version banner, noting a newer release if one exists.

    Queries the project's release API for the latest tag. This is strictly
    best-effort: any failure (missing optional dependency, network error,
    non-2xx status, malformed or unexpected JSON, unparsable version string)
    falls back to the plain banner, so callers never fail because of it.

    Args:
        current_version: The running version string, e.g. ``"2.1.0"``.

    Returns:
        A Rich markup string: the plain version, the version marked
        "(up to date)", or the version plus an
        "Update available: <current> → <latest>" notice.
    """
    plain = f"[bold green]oAI version {current_version}[/]"
    try:
        # Imported lazily, inside the try, so a missing/broken dependency only
        # disables the update check instead of crashing ``--version``.
        import requests
        from packaging import version as pkg_version

        response = requests.get(
            "https://gitlab.pm/api/v1/repos/rune/oai/releases/latest",
            # Ask for JSON; Content-Type is meaningless on a bodiless GET.
            headers={"Accept": "application/json"},
            timeout=1.0,  # keep startup snappy when the server is slow/unreachable
        )
        response.raise_for_status()
        data = response.json()
        # ``or ""`` also covers an explicit null tag; a non-dict payload raises
        # AttributeError here and is handled by the fallback below.
        version_online = (data.get("tag_name") or "").lstrip("v")
        if not version_online:
            return plain
        if pkg_version.parse(version_online) > pkg_version.parse(current_version):
            return (
                f"[bold green]oAI version {current_version}[/] "
                f"[bold red](Update available: {current_version} → {version_online})[/]"
            )
        return f"[bold green]oAI version {current_version} (up to date)[/]"
    except Exception:  # deliberate best-effort boundary (see docstring)
        return plain
def show_welcome(settings: Settings, version_info: str) -> None:
    """Print the start-up banner panel.

    Args:
        settings: Accepted for interface compatibility; not read here.
        version_info: Rich markup line shown at the top of the panel (the
            chat command passes the result of ``check_for_updates``).
    """
    tips = "\n".join(
        (
            "[bold cyan]Commands:[/] /help for commands, /model to select model",
            "[bold cyan]MCP:[/] /mcp on to enable file/database access",
            "[bold cyan]Exit:[/] Type 'exit', 'quit', or 'bye'",
        )
    )
    banner = Panel.fit(
        f"{version_info}\n\n{tips}",
        title=f"[bold green]Welcome to {APP_NAME}[/]",
        border_style="green",
    )
    console.print(banner)
def select_model(client: AIClient, search_term: Optional[str] = None) -> Optional[dict]:
    """Let the user pick a model interactively from the provider's catalogue.

    Fetches the raw model list, optionally narrows it to entries whose ``id``
    contains ``search_term`` (case-insensitive), shows a paginated table and
    reads a 1-based row number from stdin.

    Args:
        client: Client whose ``provider.get_raw_models()`` supplies the list.
        search_term: Optional case-insensitive substring matched against ids.

    Returns:
        The chosen model dict, or None when nothing is available, the user
        cancels (empty input, EOF, Ctrl+C), the input is not a valid row, or
        any error occurs (errors are reported via ``print_error``).
    """
    try:
        catalogue = client.provider.get_raw_models()
        if not catalogue:
            print_error("No models available")
            return None

        if search_term:
            needle = search_term.lower()
            catalogue = [
                entry for entry in catalogue
                if needle in entry.get("id", "").lower()
            ]
            if not catalogue:
                print_error(f"No models found matching '{search_term}'")
                return None

        display_paginated_table(
            create_model_table(catalogue),
            f"[bold green]Available Models ({len(catalogue)})[/]",
        )
        console.print("")

        try:
            raw_choice = input("Enter model number (or press Enter to cancel): ").strip()
        except (EOFError, KeyboardInterrupt):
            return None
        if not raw_choice:
            return None

        try:
            position = int(raw_choice)
        except ValueError:
            position = 0  # sentinel: never a valid 1-based row
        if 1 <= position <= len(catalogue):
            chosen = catalogue[position - 1]
            print_success(f"Selected model: {chosen['id']}")
            return chosen

        print_error("Invalid selection")
        return None
    except Exception as exc:
        print_error(f"Failed to fetch models: {exc}")
        return None
def _prompt_prefix(mcp_manager) -> str:
    """Build the input prompt, tagged with the active MCP mode (if any)."""
    if not (mcp_manager and mcp_manager.enabled):
        return "You> "
    if mcp_manager.mode == "files":
        if mcp_manager.write_enabled:
            return "[🔧✍️ MCP: Files+Write] You> "
        return "[🔧 MCP: Files] You> "
    if mcp_manager.mode == "database" and mcp_manager.selected_db_index is not None:
        # The stored index is shown to the user offset by one (1-based).
        return f"[🗄️ MCP: DB #{mcp_manager.selected_db_index + 1}] You> "
    return "You> "


def _mcp_status_emoji(mcp_manager) -> str:
    """Return the metrics-line marker for the active MCP mode, or ""."""
    if not (mcp_manager and mcp_manager.enabled):
        return ""
    if mcp_manager.mode == "files":
        return " 🔧"
    if mcp_manager.mode == "database":
        return " 🗄️"
    return ""


def _sync_session_from_context(session: ChatSession, context: CommandContext) -> None:
    """Copy state a slash command may have changed from its context onto the session."""
    session.memory_enabled = context.memory_enabled
    session.memory_start_index = context.memory_start_index
    session.online_enabled = context.online_enabled
    session.middle_out_enabled = context.middle_out_enabled
    session.session_max_token = context.session_max_token
    session.current_index = context.current_index
    session.system_prompt = context.session_system_prompt


def run_chat_loop(
    session: ChatSession,
    prompt_session: PromptSession,
    settings: Settings,
) -> None:
    """Run the interactive chat loop until the user exits.

    Each iteration reads one line, then either exits (``exit``/``quit``/``bye``
    or EOF), runs a slash command through ``registry``, or sends the text to
    the selected model and prints the response with token/cost metrics.

    A line starting with ``//`` is an escape: one slash is removed and the rest
    is sent to the model as an ordinary message — it is never run as a command.

    Args:
        session: Active chat session (model, history, stats, MCP manager).
        prompt_session: prompt_toolkit session used to read input.
        settings: Read for the streaming preference; written when a command
            asks to save the selected model as the default.
    """
    logger = get_logger()
    mcp_manager = session.mcp_manager
    while True:
        try:
            user_input = prompt_session.prompt(
                _prompt_prefix(mcp_manager),
                auto_suggest=AutoSuggestFromHistory(),
            ).strip()
            if not user_input:
                continue

            # "//text" sends the literal "/text". Remember that it was escaped:
            # after stripping one slash it still starts with "/" and must not
            # be routed to the command registry below.
            escaped = user_input.startswith("//")
            if escaped:
                user_input = user_input[1:]

            if user_input.lower() in ["exit", "quit", "bye"]:
                console.print(
                    f"\n[bold yellow]Goodbye![/]\n"
                    f"[dim]Session: {session.stats.total_tokens:,} tokens, "
                    f"${session.stats.total_cost:.4f}[/]"
                )
                logger.info(
                    f"Session ended. Messages: {session.stats.message_count}, "
                    f"Tokens: {session.stats.total_tokens}, "
                    f"Cost: ${session.stats.total_cost:.4f}"
                )
                return

            if user_input.startswith("/") and not escaped:
                cmd_word = user_input.split()[0].lower()
                if not registry.is_command(user_input):
                    # Text the registry does not recognise is still passed on
                    # when its first word starts with a known command.
                    if not any(cmd_word.startswith(cmd) for cmd in VALID_COMMANDS):
                        print_error(f"Unknown command: {cmd_word}")
                        print_info("Type /help to see available commands.")
                        continue

                context = session.get_context()
                result = registry.execute(user_input, context)
                if result:
                    _sync_session_from_context(session, context)
                    if result.status == CommandStatus.EXIT:
                        return

                    data = result.data or {}
                    if "retry_prompt" in data:
                        # Retry: resend the last prompt (falls through to send).
                        user_input = data["retry_prompt"]
                    elif "paste_prompt" in data:
                        # Paste: send clipboard content (falls through to send).
                        user_input = data["paste_prompt"]
                    else:
                        if "show_model_selector" in data:
                            search = data.get("search", "")
                            model = select_model(session.client, search or None)
                            if model:
                                session.set_model(model)
                                # Came from "/config model": also save as default.
                                if data.get("set_as_default"):
                                    settings.set_default_model(model["id"])
                                    print_success(f"Default model set to: {model['id']}")
                        elif "load_conversation" in data:
                            history = data.get("history", [])
                            session.history.clear()
                            from oai.core.session import HistoryEntry
                            for entry in history:
                                session.history.append(HistoryEntry(
                                    prompt=entry.get("prompt", ""),
                                    response=entry.get("response", ""),
                                    prompt_tokens=entry.get("prompt_tokens", 0),
                                    completion_tokens=entry.get("completion_tokens", 0),
                                    msg_cost=entry.get("msg_cost", 0.0),
                                ))
                            session.current_index = len(session.history) - 1
                        # Every other command is fully handled by the registry.
                        continue
                # A falsy result means the registry did not handle the text; it
                # is sent to the model as a normal message below.

            if not session.selected_model:
                print_warning("Please select a model first with /model")
                continue

            # Streaming is turned off whenever MCP tools are available.
            stream = settings.stream_enabled
            if mcp_manager and mcp_manager.enabled and session.get_mcp_tools():
                stream = False
            if stream:
                console.print(
                    "[bold green]Streaming response...[/] "
                    "[dim](Press Ctrl+C to cancel)[/]"
                )
                if session.online_enabled:
                    console.print("[dim cyan]🌐 Online mode active[/]")
                console.print("")

            try:
                response_text, usage, response_time = session.send_message(
                    user_input,
                    stream=stream,
                )
            except Exception as e:
                print_error(f"Error: {e}")
                logger.error(f"Message error: {e}")
                continue
            if not response_text:
                print_error("No response received")
                continue

            # Only non-streamed replies are rendered here; streamed output is
            # presumably printed by send_message while it arrives.
            if not stream:
                console.print()
                display_panel(
                    Markdown(response_text),
                    title="[bold green]AI Response[/]",
                    border_style="green",
                )

            estimated = False
            if usage and (usage.prompt_tokens > 0 or usage.completion_tokens > 0):
                tokens = usage.total_tokens
                if usage.total_cost_usd:
                    cost = usage.total_cost_usd
                else:
                    cost = session.client.estimate_cost(
                        session.selected_model["id"],
                        usage.prompt_tokens,
                        usage.completion_tokens,
                    )
            else:
                # No usage reported (e.g. streaming): estimate ~4 chars/token.
                est_input_tokens = len(user_input) // 4 + 1
                est_output_tokens = len(response_text) // 4 + 1
                tokens = est_input_tokens + est_output_tokens
                cost = session.client.estimate_cost(
                    session.selected_model["id"],
                    est_input_tokens,
                    est_output_tokens,
                )
                # Synthesised usage so session totals still advance.
                usage = UsageStats(
                    prompt_tokens=est_input_tokens,
                    completion_tokens=est_output_tokens,
                    total_tokens=tokens,
                )
                estimated = True

            session.add_to_history(user_input, response_text, usage, cost)

            est_marker = "~" if estimated else ""
            context_info = ""
            if session.memory_enabled:
                context_count = len(session.history) - session.memory_start_index
                if context_count > 1:
                    context_info = f", Context: {context_count} msg(s)"
            else:
                context_info = ", Memory: OFF"
            online_emoji = " 🌐" if session.online_enabled else ""
            mcp_emoji = _mcp_status_emoji(mcp_manager)
            console.print(
                f"\n[dim blue]📊 {est_marker}{tokens} tokens | {est_marker}${cost:.4f} | {response_time:.2f}s"
                f"{context_info}{online_emoji}{mcp_emoji} | "
                f"Session: {est_marker}{session.stats.total_tokens:,} tokens | "
                f"{est_marker}${session.stats.total_cost:.4f}[/]"
            )

            for warning in session.check_warnings():
                print_warning(warning)

            # Offering to copy is best-effort; failures must not end the loop.
            console.print("")
            try:
                from oai.ui.prompts import prompt_copy_response
                prompt_copy_response(response_text)
            except Exception as copy_error:
                logger.debug(f"Copy prompt failed: {copy_error}")
            console.print("")
        except KeyboardInterrupt:
            console.print("\n[dim]Input cancelled[/]")
            continue
        except EOFError:
            console.print("\n[bold yellow]Goodbye![/]")
            return
@app.command()
def chat(
    model: Optional[str] = typer.Option(
        None,
        "--model",
        "-m",
        help="Model ID to use",
    ),
    system: Optional[str] = typer.Option(
        None,
        "--system",
        "-s",
        help="System prompt",
    ),
    online: bool = typer.Option(
        False,
        "--online",
        "-o",
        help="Enable online mode",
    ),
    mcp: bool = typer.Option(
        False,
        "--mcp",
        help="Enable MCP server",
    ),
) -> None:
    """Start an interactive chat session."""
    # The docstring above is the user-facing ``oai chat --help`` text.
    # Validates the API key, builds the client/MCP manager/session, applies the
    # command-line overrides and then hands control to run_chat_loop.
    # Raises typer.Exit(1) when no API key is set or the client cannot be built.
    logging_manager = LoggingManager()
    logging_manager.setup()
    logger = get_logger()

    clear_screen()
    settings = Settings.load()

    if not settings.api_key:
        print_error("No API key configured")
        # The key is set through the ``config`` sub-command; there is no
        # ``--config`` option.
        print_info("Run: oai config api to set your API key")
        raise typer.Exit(1)

    try:
        client = AIClient(
            api_key=settings.api_key,
            base_url=settings.base_url,
        )
    except Exception as e:
        logger.error(f"Failed to initialize client: {e}")
        print_error(f"Failed to initialize client: {e}")
        raise typer.Exit(1) from e

    # Slash commands must be registered before the loop dispatches to them.
    register_all_commands()

    show_welcome(settings, check_for_updates(APP_VERSION))

    mcp_manager = MCPManager()
    if mcp:
        result = mcp_manager.enable()
        # A result without a "success" key is treated as failure, not a KeyError.
        if result.get("success"):
            print_success("MCP enabled")
        else:
            print_warning(f"MCP: {result.get('error', 'Failed to enable')}")

    session = ChatSession(
        client=client,
        settings=settings,
        mcp_manager=mcp_manager,
    )

    if system:
        session.system_prompt = system
        print_info(f"System prompt: {system}")

    if online:
        session.online_enabled = True
        print_info("Online mode enabled")

    # An explicit --model wins; otherwise fall back to the configured default.
    if model:
        raw_model = client.get_raw_model(model)
        if raw_model:
            session.set_model(raw_model)
        else:
            print_warning(f"Model '{model}' not found")
    elif settings.default_model:
        raw_model = client.get_raw_model(settings.default_model)
        if raw_model:
            session.set_model(raw_model)
        else:
            print_warning(f"Default model '{settings.default_model}' not available")

    # Make sure the history file's directory exists before prompt_toolkit uses it.
    HISTORY_FILE.parent.mkdir(parents=True, exist_ok=True)
    prompt_session = PromptSession(
        history=FileHistory(str(HISTORY_FILE)),
    )

    run_chat_loop(session, prompt_session, settings)
def _decode_escapes(text: str) -> str:
    r"""Expand backslash escapes such as ``\n`` and ``\t`` in *text*.

    Unlike ``text.encode().decode("unicode_escape")`` this keeps non-ASCII
    characters intact, and it returns *text* unchanged when the escapes are
    malformed (for example a trailing backslash) instead of raising.
    """
    try:
        # Latin-1 maps code points 0-255 to identical bytes, which
        # unicode_escape decodes back to the same code points. Anything beyond
        # Latin-1 is first written as a \uXXXX / \UXXXXXXXX escape by
        # backslashreplace, which the decoder then restores.
        return text.encode("latin-1", "backslashreplace").decode("unicode_escape")
    except UnicodeDecodeError:
        return text


@app.command()
def config(
    setting: Optional[str] = typer.Argument(
        None,
        help="Setting to configure (api, url, model, system, stream, costwarning, maxtoken, online, loglevel)",
    ),
    value: Optional[str] = typer.Argument(
        None,
        help="Value to set",
    ),
) -> None:
    """View or modify configuration settings."""
    # The docstring above is the user-facing ``oai config --help`` text.
    # With no argument every setting is listed. Otherwise ``setting`` selects
    # one entry: when ``value`` is given it is applied via a Settings.set_*
    # method; without it most settings show their current value or a usage
    # hint, ``api`` prompts for the key and ``url`` resets to OpenRouter.
    settings = Settings.load()

    if not setting:
        from rich.markup import escape
        from rich.table import Table
        from oai.constants import DEFAULT_SYSTEM_PROMPT

        table = Table("Setting", "Value", show_header=True, header_style="bold magenta")
        table.add_row("API Key", ("***" + settings.api_key[-4:]) if settings.api_key else "Not set")
        table.add_row("Base URL", settings.base_url)
        table.add_row("Default Model", settings.default_model or "Not set")

        # Table cells are parsed as Rich markup: escape the bracketed labels
        # (otherwise "[blank]"/"[default]" are eaten as style tags) and the
        # prompt text itself (a stray "[/tag]" would raise MarkupError).
        prompt = settings.default_system_prompt
        if prompt is None:
            system_prompt_display = escape(f"[default] {DEFAULT_SYSTEM_PROMPT[:40]}...")
        elif prompt == "":
            system_prompt_display = escape("[blank]")
        elif len(prompt) > 50:
            system_prompt_display = escape(prompt[:50] + "...")
        else:
            system_prompt_display = escape(prompt)
        table.add_row("System Prompt", system_prompt_display)

        table.add_row("Streaming", "on" if settings.stream_enabled else "off")
        table.add_row("Cost Warning", f"${settings.cost_warning_threshold:.4f}")
        table.add_row("Max Tokens", str(settings.max_tokens))
        table.add_row("Default Online", "on" if settings.default_online_mode else "off")
        table.add_row("Log Level", settings.log_level)
        display_panel(table, title="[bold green]Configuration[/]")
        return

    setting = setting.lower()

    if setting == "api":
        new_key = value
        if not new_key:
            from oai.ui.prompts import prompt_input
            new_key = prompt_input("Enter API key", password=True)
        # Only report success when a key was actually stored.
        if new_key:
            settings.set_api_key(new_key)
            print_success("API key updated")
        else:
            print_warning("API key not changed")
    elif setting == "url":
        settings.set_base_url(value or "https://openrouter.ai/api/v1")
        print_success(f"Base URL set to: {settings.base_url}")
    elif setting == "model":
        if value:
            settings.set_default_model(value)
            print_success(f"Default model set to: {value}")
        else:
            print_info(f"Current default model: {settings.default_model or 'Not set'}")
    elif setting == "system":
        from oai.constants import DEFAULT_SYSTEM_PROMPT
        # An explicit empty value (``oai config system ""``) sets a blank
        # prompt; omitting the value shows the current one.
        if value is not None:
            # Let users type escapes such as "\n" on the command line.
            value = _decode_escapes(value)
            settings.set_default_system_prompt(value)
            if value:
                print_success(f"Default system prompt set to: {value}")
            else:
                print_success("Default system prompt set to blank.")
        else:
            if settings.default_system_prompt is None:
                print_info(f"Using hardcoded default: {DEFAULT_SYSTEM_PROMPT[:60]}...")
            elif settings.default_system_prompt == "":
                print_info("System prompt: [blank]")
            else:
                print_info(f"System prompt: {settings.default_system_prompt}")
    elif setting == "stream":
        if value and value.lower() in ["on", "off"]:
            settings.set_stream_enabled(value.lower() == "on")
            print_success(f"Streaming {'enabled' if settings.stream_enabled else 'disabled'}")
        else:
            print_info("Usage: oai config stream [on|off]")
    elif setting == "costwarning":
        if value:
            try:
                threshold = float(value)
                settings.set_cost_warning_threshold(threshold)
                print_success(f"Cost warning threshold set to: ${threshold:.4f}")
            except ValueError:
                print_error("Please enter a valid number")
        else:
            print_info(f"Current threshold: ${settings.cost_warning_threshold:.4f}")
    elif setting == "maxtoken":
        if value:
            try:
                max_tok = int(value)
                settings.set_max_tokens(max_tok)
                print_success(f"Max tokens set to: {max_tok}")
            except ValueError:
                print_error("Please enter a valid number")
        else:
            print_info(f"Current max tokens: {settings.max_tokens}")
    elif setting == "online":
        if value and value.lower() in ["on", "off"]:
            settings.set_default_online_mode(value.lower() == "on")
            print_success(f"Default online mode {'enabled' if settings.default_online_mode else 'disabled'}")
        else:
            print_info("Usage: oai config online [on|off]")
    elif setting == "loglevel":
        valid_levels = ["debug", "info", "warning", "error", "critical"]
        if value and value.lower() in valid_levels:
            settings.set_log_level(value.lower())
            print_success(f"Log level set to: {value.lower()}")
        else:
            print_info(f"Valid levels: {', '.join(valid_levels)}")
    else:
        print_error(f"Unknown setting: {setting}")
@app.command()
def version() -> None:
    """Show version information."""
    # Prints the same banner as ``oai --version``, including the best-effort
    # update check.
    banner = check_for_updates(APP_VERSION)
    console.print(banner)
@app.command()
def credits() -> None:
    """Check account credits."""
    # NOTE: this function name shadows the ``credits`` builtin at module level;
    # it is kept because Typer derives the sub-command name from it.
    # Exits with code 1 when no API key is configured or no data comes back.
    from rich.table import Table

    settings = Settings.load()
    if not settings.api_key:
        print_error("No API key configured")
        raise typer.Exit(1)

    client = AIClient(api_key=settings.api_key, base_url=settings.base_url)
    credits_data = client.get_credits()
    if not credits_data:
        print_error("Failed to fetch credits")
        raise typer.Exit(1)

    table = Table("Metric", "Value", show_header=True, header_style="bold magenta")
    for label, key in (
        ("Total Credits", "total_credits_formatted"),
        ("Used Credits", "used_credits_formatted"),
        ("Credits Left", "credits_left_formatted"),
    ):
        table.add_row(label, credits_data.get(key, "N/A"))
    display_panel(table, title="[bold green]Account Credits[/]")
def main() -> None:
    """Program entry point: run the Typer app, defaulting to ``chat``."""
    # A bare ``oai`` (argv holds only the program name) starts a chat session.
    bare_invocation = len(sys.argv) == 1
    if bare_invocation:
        sys.argv.append("chat")
    app()


if __name__ == "__main__":
    main()