"""
|
|
OpenRouter provider implementation.
|
|
|
|
This module implements the AIProvider interface for OpenRouter,
|
|
supporting chat completions, streaming, and function calling.
|
|
"""
|
|
|
|
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
|
|
|
|
import requests
|
|
from openrouter import OpenRouter
|
|
|
|
from oai.constants import APP_NAME, APP_URL, DEFAULT_BASE_URL
|
|
from oai.providers.base import (
|
|
AIProvider,
|
|
ChatMessage,
|
|
ChatResponse,
|
|
ChatResponseChoice,
|
|
ModelInfo,
|
|
ProviderCapabilities,
|
|
StreamChunk,
|
|
ToolCall,
|
|
ToolFunction,
|
|
UsageStats,
|
|
)
|
|
from oai.utils.logging import get_logger
|
|
|
|
|
|
class OpenRouterProvider(AIProvider):
|
|
"""
|
|
OpenRouter API provider implementation.
|
|
|
|
Provides access to multiple AI models through OpenRouter's unified API,
|
|
supporting chat completions, streaming responses, and function calling.
|
|
|
|
Attributes:
|
|
client: The underlying OpenRouter client
|
|
_models_cache: Cached list of available models
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
api_key: str,
|
|
base_url: Optional[str] = None,
|
|
app_name: str = APP_NAME,
|
|
app_url: str = APP_URL,
|
|
):
|
|
"""
|
|
Initialize the OpenRouter provider.
|
|
|
|
Args:
|
|
api_key: OpenRouter API key
|
|
base_url: Optional custom base URL
|
|
app_name: Application name for API headers
|
|
app_url: Application URL for API headers
|
|
"""
|
|
super().__init__(api_key, base_url or DEFAULT_BASE_URL)
|
|
self.app_name = app_name
|
|
self.app_url = app_url
|
|
self.client = OpenRouter(api_key=api_key)
|
|
self._models_cache: Optional[List[ModelInfo]] = None
|
|
self._raw_models_cache: Optional[List[Dict[str, Any]]] = None
|
|
|
|
self.logger = get_logger()
|
|
self.logger.info(f"OpenRouter provider initialized with base URL: {self.base_url}")
|
|
|
|
@property
|
|
def name(self) -> str:
|
|
"""Get the provider name."""
|
|
return "OpenRouter"
|
|
|
|

    @property
    def capabilities(self) -> ProviderCapabilities:
        """Get provider capabilities."""
        return ProviderCapabilities(
            streaming=True,
            tools=True,
            images=True,
            online=True,
            max_context=2_000_000,  # Upper bound; some models on OpenRouter offer contexts up to ~2M tokens
        )

    def _get_headers(self) -> Dict[str, str]:
        """Get standard HTTP headers for API requests."""
        headers = {
            "HTTP-Referer": self.app_url,
            "X-Title": self.app_name,
        }
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers
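
    # Note: "HTTP-Referer" and "X-Title" are OpenRouter's optional app-attribution
    # headers; authentication itself is handled by the Bearer token added above.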

    def _parse_model(self, model_data: Dict[str, Any]) -> ModelInfo:
        """
        Parse raw model data into ModelInfo.

        Args:
            model_data: Raw model data from API

        Returns:
            Parsed ModelInfo object
        """
        architecture = model_data.get("architecture", {})
        pricing_data = model_data.get("pricing", {})

        # Parse pricing (convert from string to float if needed)
        pricing = {}
        for key in ["prompt", "completion"]:
            value = pricing_data.get(key)
            if value is not None:
                try:
                    # Convert from per-token to per-million-tokens
                    pricing[key] = float(value) * 1_000_000
                except (ValueError, TypeError):
                    pricing[key] = 0.0

        return ModelInfo(
            id=model_data.get("id", ""),
            name=model_data.get("name", model_data.get("id", "")),
            description=model_data.get("description", ""),
            context_length=model_data.get("context_length", 0),
            pricing=pricing,
            supported_parameters=model_data.get("supported_parameters", []),
            input_modalities=architecture.get("input_modalities", ["text"]),
            output_modalities=architecture.get("output_modalities", ["text"]),
        )
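
    # Illustration (hypothetical payload): a model priced by the API at "0.000003"
    # USD per prompt token and "0.000015" per completion token is normalized by
    # _parse_model to pricing == {"prompt": 3.0, "completion": 15.0}, i.e. USD per
    # million tokens.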

    def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
        """
        Fetch available models from OpenRouter.

        Args:
            filter_text_only: If True, exclude video-only models

        Returns:
            List of available models

        Raises:
            Exception: If API request fails
        """
        if self._models_cache is not None:
            return self._models_cache

        try:
            response = requests.get(
                f"{self.base_url}/models",
                headers=self._get_headers(),
                timeout=10,
            )
            response.raise_for_status()

            raw_models = response.json().get("data", [])
            self._raw_models_cache = raw_models

            models = []
            for model_data in raw_models:
                # Optionally filter out video-only models
                if filter_text_only:
                    modalities = model_data.get("modalities", [])
                    if modalities and "video" in modalities and "text" not in modalities:
                        continue

                models.append(self._parse_model(model_data))

            self._models_cache = models
            self.logger.info(f"Fetched {len(models)} models from OpenRouter")
            return models

        except requests.RequestException as e:
            self.logger.error(f"Failed to fetch models: {e}")
            raise

    def get_raw_models(self) -> List[Dict[str, Any]]:
        """
        Get raw model data as returned by the API.

        Useful for accessing provider-specific fields not in ModelInfo.

        Returns:
            List of raw model dictionaries
        """
        if self._raw_models_cache is None:
            self.list_models()
        return self._raw_models_cache or []

    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        """
        Get information about a specific model.

        Args:
            model_id: The model identifier

        Returns:
            Model information or None if not found
        """
        models = self.list_models()
        for model in models:
            if model.id == model_id:
                return model
        return None

    def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get raw model data for a specific model.

        Args:
            model_id: The model identifier

        Returns:
            Raw model dictionary or None if not found
        """
        raw_models = self.get_raw_models()
        for model in raw_models:
            if model.get("id") == model_id:
                return model
        return None

    def _convert_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
        """
        Convert ChatMessage objects to API format.

        Args:
            messages: List of ChatMessage objects

        Returns:
            List of message dictionaries for the API
        """
        return [msg.to_dict() for msg in messages]

    def _parse_usage(self, usage_data: Any) -> Optional[UsageStats]:
        """
        Parse usage data from API response.

        Args:
            usage_data: Raw usage data from API

        Returns:
            Parsed UsageStats or None
        """
        if not usage_data:
            return None

        # Handle both attribute and dict access
        prompt_tokens = 0
        completion_tokens = 0
        total_cost = None

        if hasattr(usage_data, "prompt_tokens"):
            prompt_tokens = getattr(usage_data, "prompt_tokens", 0) or 0
        elif isinstance(usage_data, dict):
            prompt_tokens = usage_data.get("prompt_tokens", 0) or 0

        if hasattr(usage_data, "completion_tokens"):
            completion_tokens = getattr(usage_data, "completion_tokens", 0) or 0
        elif isinstance(usage_data, dict):
            completion_tokens = usage_data.get("completion_tokens", 0) or 0

        # Try alternative naming (input_tokens/output_tokens)
        if prompt_tokens == 0:
            if hasattr(usage_data, "input_tokens"):
                prompt_tokens = getattr(usage_data, "input_tokens", 0) or 0
            elif isinstance(usage_data, dict):
                prompt_tokens = usage_data.get("input_tokens", 0) or 0

        if completion_tokens == 0:
            if hasattr(usage_data, "output_tokens"):
                completion_tokens = getattr(usage_data, "output_tokens", 0) or 0
            elif isinstance(usage_data, dict):
                completion_tokens = usage_data.get("output_tokens", 0) or 0

        # Get cost if available
        if hasattr(usage_data, "total_cost_usd"):
            total_cost = getattr(usage_data, "total_cost_usd", None)
        elif isinstance(usage_data, dict):
            total_cost = usage_data.get("total_cost_usd")

        return UsageStats(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            total_cost_usd=float(total_cost) if total_cost else None,
        )
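
    # Example (hypothetical usage payload): {"prompt_tokens": 120, "completion_tokens": 48,
    # "total_cost_usd": 0.00092} parses to UsageStats(prompt_tokens=120, completion_tokens=48,
    # total_tokens=168, total_cost_usd=0.00092).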

    def _parse_tool_calls(self, tool_calls_data: Any) -> Optional[List[ToolCall]]:
        """
        Parse tool calls from API response.

        Args:
            tool_calls_data: Raw tool calls data

        Returns:
            List of ToolCall objects or None
        """
        if not tool_calls_data:
            return None

        tool_calls = []
        for tc in tool_calls_data:
            # Handle both attribute and dict access
            if hasattr(tc, "id"):
                tc_id = tc.id
                tc_type = getattr(tc, "type", "function")
                func = tc.function
                func_name = func.name
                func_args = func.arguments
            else:
                tc_id = tc.get("id", "")
                tc_type = tc.get("type", "function")
                func = tc.get("function", {})
                func_name = func.get("name", "")
                func_args = func.get("arguments", "{}")

            tool_calls.append(
                ToolCall(
                    id=tc_id,
                    type=tc_type,
                    function=ToolFunction(name=func_name, arguments=func_args),
                )
            )

        return tool_calls if tool_calls else None

    def _parse_response(self, response: Any) -> ChatResponse:
        """
        Parse API response into ChatResponse.

        Args:
            response: Raw API response

        Returns:
            Parsed ChatResponse
        """
        choices = []
        for choice in response.choices:
            msg = choice.message
            message = ChatMessage(
                role=msg.role if hasattr(msg, "role") else "assistant",
                content=msg.content if hasattr(msg, "content") else None,
                tool_calls=self._parse_tool_calls(
                    getattr(msg, "tool_calls", None)
                ),
            )
            choices.append(
                ChatResponseChoice(
                    index=choice.index if hasattr(choice, "index") else 0,
                    message=message,
                    finish_reason=getattr(choice, "finish_reason", None),
                )
            )

        return ChatResponse(
            id=response.id if hasattr(response, "id") else "",
            choices=choices,
            usage=self._parse_usage(getattr(response, "usage", None)),
            model=getattr(response, "model", None),
            created=getattr(response, "created", None),
        )

    def chat(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        transforms: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Union[ChatResponse, Iterator[StreamChunk]]:
        """
        Send a chat completion request to OpenRouter.

        Args:
            model: Model ID to use
            messages: List of chat messages
            stream: Whether to stream the response
            max_tokens: Maximum tokens in response
            temperature: Sampling temperature (0-2)
            tools: List of tool definitions for function calling
            tool_choice: How to handle tool selection ("auto", "none", etc.)
            transforms: List of transforms (e.g., ["middle-out"])
            **kwargs: Additional parameters

        Returns:
            ChatResponse for non-streaming, Iterator[StreamChunk] for streaming
        """
        # Build request parameters
        params: Dict[str, Any] = {
            "model": model,
            "messages": self._convert_messages(messages),
            "stream": stream,
            "http_headers": self._get_headers(),
        }

        # Request usage stats in streaming responses
        if stream:
            params["stream_options"] = {"include_usage": True}

        if max_tokens is not None:
            params["max_tokens"] = max_tokens

        if temperature is not None:
            params["temperature"] = temperature

        if tools:
            params["tools"] = tools
            params["tool_choice"] = tool_choice or "auto"

        if transforms:
            params["transforms"] = transforms

        # Add any additional parameters
        params.update(kwargs)

        self.logger.debug(f"Sending chat request to model {model}")

        try:
            response = self.client.chat.send(**params)

            if stream:
                return self._stream_response(response)
            else:
                return self._parse_response(response)

        except Exception as e:
            self.logger.error(f"Chat request failed: {e}")
            raise
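
    # Usage sketch (illustrative only; the model id and key below are placeholders,
    # not values defined in this module):
    #
    #   provider = OpenRouterProvider(api_key="sk-or-...")
    #   reply = provider.chat(
    #       model="openai/gpt-4o-mini",
    #       messages=[ChatMessage(role="user", content="Hello!")],
    #   )
    #   print(reply.choices[0].message.content)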

    def _stream_response(self, response: Any) -> Iterator[StreamChunk]:
        """
        Process a streaming response.

        Args:
            response: Streaming response from API

        Yields:
            StreamChunk objects
        """
        last_usage = None

        try:
            for chunk in response:
                # Check for errors
                if hasattr(chunk, "error") and chunk.error:
                    yield StreamChunk(
                        id=getattr(chunk, "id", ""),
                        error=chunk.error.message if hasattr(chunk.error, "message") else str(chunk.error),
                    )
                    return

                # Extract delta content
                delta_content = None
                finish_reason = None

                if hasattr(chunk, "choices") and chunk.choices:
                    choice = chunk.choices[0]
                    if hasattr(choice, "delta"):
                        delta = choice.delta
                        if hasattr(delta, "content") and delta.content:
                            delta_content = delta.content
                    finish_reason = getattr(choice, "finish_reason", None)

                # Track usage from last chunk
                if hasattr(chunk, "usage") and chunk.usage:
                    last_usage = self._parse_usage(chunk.usage)

                yield StreamChunk(
                    id=getattr(chunk, "id", ""),
                    delta_content=delta_content,
                    finish_reason=finish_reason,
                    usage=last_usage if finish_reason else None,
                )

        except Exception as e:
            self.logger.error(f"Stream error: {e}")
            yield StreamChunk(id="", error=str(e))
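
    # Streaming sketch (illustrative; "provider", "model_id", and "msgs" are
    # placeholder names): callers typically drain the iterator returned by
    # chat(..., stream=True):
    #
    #   for chunk in provider.chat(model=model_id, messages=msgs, stream=True):
    #       if chunk.error:
    #           break
    #       if chunk.delta_content:
    #           print(chunk.delta_content, end="", flush=True)
    #       if chunk.finish_reason and chunk.usage:
    #           cost = chunk.usage.total_cost_usd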

    async def chat_async(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
        """
        Send an async chat completion request.

        Note: Currently wraps the sync implementation.
        TODO: Implement true async support when OpenRouter SDK supports it.

        Args:
            model: Model ID to use
            messages: List of chat messages
            stream: Whether to stream the response
            max_tokens: Maximum tokens in response
            temperature: Sampling temperature
            tools: List of tool definitions
            tool_choice: Tool selection mode
            **kwargs: Additional parameters

        Returns:
            ChatResponse for non-streaming, AsyncIterator for streaming
        """
        # For now, use sync implementation
        # TODO: Add true async when SDK supports it
        result = self.chat(
            model=model,
            messages=messages,
            stream=stream,
            max_tokens=max_tokens,
            temperature=temperature,
            tools=tools,
            tool_choice=tool_choice,
            **kwargs,
        )

        if stream and isinstance(result, Iterator):
            # Convert sync iterator to async
            async def async_iter() -> AsyncIterator[StreamChunk]:
                for chunk in result:
                    yield chunk

            return async_iter()

        return result
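
    # Async sketch (illustrative; names are placeholders): for streaming, await the
    # coroutine first, then iterate the async generator it returns:
    #
    #   stream = await provider.chat_async(model=model_id, messages=msgs, stream=True)
    #   async for chunk in stream:
    #       ...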

    def get_credits(self) -> Optional[Dict[str, Any]]:
        """
        Get OpenRouter account credit information.

        Returns:
            Dict with credit info, or None if no API key is set or the request fails:
                - total_credits: Total credits purchased
                - used_credits: Credits used
                - credits_left: Remaining credits
        """
        if not self.api_key:
            return None

        try:
            response = requests.get(
                f"{self.base_url}/credits",
                headers=self._get_headers(),
                timeout=10,
            )
            response.raise_for_status()

            data = response.json().get("data", {})
            total_credits = float(data.get("total_credits", 0))
            total_usage = float(data.get("total_usage", 0))
            credits_left = total_credits - total_usage

            return {
                "total_credits": total_credits,
                "used_credits": total_usage,
                "credits_left": credits_left,
                "total_credits_formatted": f"${total_credits:.2f}",
                "used_credits_formatted": f"${total_usage:.2f}",
                "credits_left_formatted": f"${credits_left:.2f}",
            }

        except Exception as e:
            self.logger.error(f"Failed to fetch credits: {e}")
            return None

    def clear_cache(self) -> None:
        """Clear the models cache to force a refresh."""
        self._models_cache = None
        self._raw_models_cache = None
        self.logger.debug("Models cache cleared")

    def get_effective_model_id(self, model_id: str, online_enabled: bool) -> str:
        """
        Get the effective model ID with online suffix if needed.

        Args:
            model_id: Base model ID
            online_enabled: Whether online mode is enabled

        Returns:
            Model ID with :online suffix if applicable
        """
        if online_enabled and not model_id.endswith(":online"):
            return f"{model_id}:online"
        return model_id
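
    # e.g. get_effective_model_id("openai/gpt-4o", online_enabled=True) returns
    # "openai/gpt-4o:online" (illustrative model id).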

    def estimate_cost(
        self,
        model_id: str,
        input_tokens: int,
        output_tokens: int,
    ) -> float:
        """
        Estimate the cost for a completion.

        Args:
            model_id: Model ID
            input_tokens: Number of input tokens
            output_tokens: Number of output tokens

        Returns:
            Estimated cost in USD
        """
        model = self.get_model(model_id)
        if model and model.pricing:
            input_cost = model.pricing.get("prompt", 0) * input_tokens / 1_000_000
            output_cost = model.pricing.get("completion", 0) * output_tokens / 1_000_000
            return input_cost + output_cost

        # Fallback to default pricing if model not found
        from oai.constants import MODEL_PRICING

        input_cost = MODEL_PRICING["input"] * input_tokens / 1_000_000
        output_cost = MODEL_PRICING["output"] * output_tokens / 1_000_000
        return input_cost + output_cost
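
    # Worked example (hypothetical pricing): with pricing == {"prompt": 3.0,
    # "completion": 15.0} (USD per million tokens), estimate_cost(model_id, 1_000, 500)
    # returns 3.0 * 1_000 / 1e6 + 15.0 * 500 / 1e6 == 0.0105 USD.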