Bug fixes. Added missing functionality. ++
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -46,3 +46,5 @@ requirements.txt
|
|||||||
system_prompt.txt
|
system_prompt.txt
|
||||||
CLAUDE*
|
CLAUDE*
|
||||||
SESSION*_COMPLETE.md
|
SESSION*_COMPLETE.md
|
||||||
|
oai/FILE_ATTACHMENTS_FIX.md
|
||||||
|
oai/OLLAMA_ERROR_HANDLING.md
|
||||||
|
|||||||
29
README.md
29
README.md
@@ -3,7 +3,7 @@
|
|||||||
A powerful, modern **Textual TUI** chat client with **multi-provider support** (OpenRouter, Anthropic, OpenAI, Ollama) and **MCP (Model Context Protocol)** integration, enabling AI to access local files and query SQLite databases.
|
A powerful, modern **Textual TUI** chat client with **multi-provider support** (OpenRouter, Anthropic, OpenAI, Ollama) and **MCP (Model Context Protocol)** integration, enabling AI to access local files and query SQLite databases.
|
||||||
|
|
||||||
>[!WARNING]
|
>[!WARNING]
|
||||||
> v3.0.0-b3 is beta software. While I strive for stability, beta versions may contain bugs, incomplete features, or unexpected behavior. I actively work on improvements and appreciate your feedback.
|
> v3.0.0-b4 is beta software. While I strive for stability, beta versions may contain bugs, incomplete features, or unexpected behavior. I actively work on improvements and appreciate your feedback.
|
||||||
>
|
>
|
||||||
>Beta releases are ideal for testing new features and providing feedback. For production use or maximum stability, consider using the latest stable release.
|
>Beta releases are ideal for testing new features and providing feedback. For production use or maximum stability, consider using the latest stable release.
|
||||||
|
|
||||||
@@ -146,6 +146,33 @@ On first run, you'll be prompted for your API key. Configure additional provider
|
|||||||
Ctrl+Q # Quit
|
Ctrl+Q # Quit
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### File Attachments
|
||||||
|
|
||||||
|
Attach files to your messages using `@<file>` or `@file` syntax:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Single file
|
||||||
|
Describe this image @<photo.jpg>
|
||||||
|
Analyze @~/Documents/report.pdf
|
||||||
|
|
||||||
|
# Multiple files
|
||||||
|
Compare @image1.png and @image2.png
|
||||||
|
|
||||||
|
# With paths
|
||||||
|
Review this code @./src/main.py
|
||||||
|
What's in @/Users/username/Downloads/screenshot.png
|
||||||
|
```
|
||||||
|
|
||||||
|
**Supported file types:**
|
||||||
|
- **Images**: `.png`, `.jpg`, `.jpeg`, `.gif`, `.webp`, `.bmp` (for vision models)
|
||||||
|
- **PDFs**: `.pdf` (for document-capable models)
|
||||||
|
- **Code/Text**: `.py`, `.js`, `.md`, `.txt`, `.json`, etc. (all models)
|
||||||
|
|
||||||
|
**Notes:**
|
||||||
|
- Files are automatically base64-encoded for images/PDFs
|
||||||
|
- Maximum file size: 10MB per file
|
||||||
|
- Works with vision-capable models (Claude, GPT-4o, Gemini, etc.)
|
||||||
|
|
||||||
## Web Search
|
## Web Search
|
||||||
|
|
||||||
oAI provides universal web search capabilities for all AI providers with three options:
|
oAI provides universal web search capabilities for all AI providers with three options:
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ Author: Rune
|
|||||||
License: MIT
|
License: MIT
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "3.0.0-b3"
|
__version__ = "3.0.0-b4"
|
||||||
__author__ = "Rune Olsen"
|
__author__ = "Rune Olsen"
|
||||||
__license__ = "MIT"
|
__license__ = "MIT"
|
||||||
|
|
||||||
|
|||||||
17
oai/cli.py
17
oai/cli.py
@@ -125,6 +125,23 @@ def _launch_tui(
|
|||||||
provider_api_keys=provider_api_keys,
|
provider_api_keys=provider_api_keys,
|
||||||
ollama_base_url=settings.ollama_base_url,
|
ollama_base_url=settings.ollama_base_url,
|
||||||
)
|
)
|
||||||
|
except ConnectionError as e:
|
||||||
|
# Special handling for Ollama connection errors
|
||||||
|
if selected_provider == "ollama":
|
||||||
|
typer.echo(f"Error: Cannot connect to Ollama server", err=True)
|
||||||
|
typer.echo(f"", err=True)
|
||||||
|
typer.echo(f"Details: {e}", err=True)
|
||||||
|
typer.echo(f"", err=True)
|
||||||
|
typer.echo(f"Please make sure Ollama is running:", err=True)
|
||||||
|
typer.echo(f" ollama serve", err=True)
|
||||||
|
typer.echo(f"", err=True)
|
||||||
|
typer.echo(f"Or configure a different Ollama URL:", err=True)
|
||||||
|
typer.echo(f" oai config ollama_base_url <url>", err=True)
|
||||||
|
typer.echo(f"", err=True)
|
||||||
|
typer.echo(f"Current Ollama URL: {settings.ollama_base_url}", err=True)
|
||||||
|
else:
|
||||||
|
typer.echo(f"Error: Connection failed: {e}", err=True)
|
||||||
|
raise typer.Exit(1)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
typer.echo(f"Error: Failed to initialize client: {e}", err=True)
|
typer.echo(f"Error: Failed to initialize client: {e}", err=True)
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
|||||||
@@ -1589,10 +1589,30 @@ class ProviderCommand(Command):
|
|||||||
else:
|
else:
|
||||||
# No models available, clear selection
|
# No models available, clear selection
|
||||||
context.session.selected_model = None
|
context.session.selected_model = None
|
||||||
message = f"Switched to {provider_name} provider (no models available)"
|
if provider_name == "ollama":
|
||||||
return CommandResult.success(message=message)
|
# Special message for Ollama with helpful instructions
|
||||||
|
message = f"⚠️ Switched to Ollama, but no models are installed.\n\n"
|
||||||
|
message += "To install models, run:\n"
|
||||||
|
message += " ollama pull llama3.2\n"
|
||||||
|
message += " ollama pull mistral\n\n"
|
||||||
|
message += "See available models at: https://ollama.com/library"
|
||||||
|
return CommandResult.warning(message=message)
|
||||||
|
else:
|
||||||
|
message = f"Switched to {provider_name} provider (no models available)"
|
||||||
|
return CommandResult.warning(message=message)
|
||||||
|
|
||||||
return CommandResult.success(message=f"Switched to {provider_name} provider")
|
return CommandResult.success(message=f"Switched to {provider_name} provider")
|
||||||
|
except ConnectionError as e:
|
||||||
|
# Connection error - likely Ollama server not running
|
||||||
|
if provider_name == "ollama":
|
||||||
|
error_msg = f"❌ Cannot connect to Ollama server.\n\n"
|
||||||
|
error_msg += f"Error: {str(e)}\n\n"
|
||||||
|
error_msg += "Please make sure Ollama is running:\n"
|
||||||
|
error_msg += " ollama serve\n\n"
|
||||||
|
error_msg += f"Or check your Ollama URL in config (currently: {context.settings.ollama_base_url})"
|
||||||
|
return CommandResult.error(error_msg)
|
||||||
|
else:
|
||||||
|
return CommandResult.error(f"Connection error: {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return CommandResult.error(f"Failed to switch provider: {e}")
|
return CommandResult.error(f"Failed to switch provider: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ class CommandStatus(str, Enum):
|
|||||||
"""Status of command execution."""
|
"""Status of command execution."""
|
||||||
|
|
||||||
SUCCESS = "success"
|
SUCCESS = "success"
|
||||||
|
WARNING = "warning"
|
||||||
ERROR = "error"
|
ERROR = "error"
|
||||||
CONTINUE = "continue" # Continue to next handler
|
CONTINUE = "continue" # Continue to next handler
|
||||||
EXIT = "exit" # Exit the application
|
EXIT = "exit" # Exit the application
|
||||||
@@ -50,6 +51,11 @@ class CommandResult:
|
|||||||
"""Create a success result."""
|
"""Create a success result."""
|
||||||
return cls(status=CommandStatus.SUCCESS, message=message, data=data)
|
return cls(status=CommandStatus.SUCCESS, message=message, data=data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def warning(cls, message: str, data: Any = None) -> "CommandResult":
|
||||||
|
"""Create a warning result."""
|
||||||
|
return cls(status=CommandStatus.WARNING, message=message, data=data)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def error(cls, message: str) -> "CommandResult":
|
def error(cls, message: str) -> "CommandResult":
|
||||||
"""Create an error result."""
|
"""Create an error result."""
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from oai.mcp.manager import MCPManager
|
|||||||
from oai.providers.base import ChatResponse, StreamChunk, UsageStats
|
from oai.providers.base import ChatResponse, StreamChunk, UsageStats
|
||||||
from oai.utils.logging import get_logger
|
from oai.utils.logging import get_logger
|
||||||
from oai.utils.web_search import perform_web_search, format_search_results
|
from oai.utils.web_search import perform_web_search, format_search_results
|
||||||
|
from oai.utils.files import parse_file_attachments, prepare_file_attachment
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
@@ -218,8 +219,40 @@ class ChatSession:
|
|||||||
messages.append({"role": "user", "content": entry.prompt})
|
messages.append({"role": "user", "content": entry.prompt})
|
||||||
messages.append({"role": "assistant", "content": entry.response})
|
messages.append({"role": "assistant", "content": entry.response})
|
||||||
|
|
||||||
# Add current message
|
# Parse file attachments from user input
|
||||||
messages.append({"role": "user", "content": user_input})
|
cleaned_text, file_paths = parse_file_attachments(user_input)
|
||||||
|
|
||||||
|
# Build content for current message
|
||||||
|
if file_paths:
|
||||||
|
# Multi-modal message with attachments
|
||||||
|
content_parts = []
|
||||||
|
|
||||||
|
# Add text part if there's any text
|
||||||
|
if cleaned_text.strip():
|
||||||
|
content_parts.append({
|
||||||
|
"type": "text",
|
||||||
|
"text": cleaned_text.strip()
|
||||||
|
})
|
||||||
|
|
||||||
|
# Add file attachments
|
||||||
|
for file_path in file_paths:
|
||||||
|
attachment = prepare_file_attachment(
|
||||||
|
file_path,
|
||||||
|
self.selected_model or {}
|
||||||
|
)
|
||||||
|
if attachment:
|
||||||
|
content_parts.append(attachment)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Could not attach file: {file_path}")
|
||||||
|
|
||||||
|
# If we have content parts, use them; otherwise fall back to text
|
||||||
|
if content_parts:
|
||||||
|
messages.append({"role": "user", "content": content_parts})
|
||||||
|
else:
|
||||||
|
messages.append({"role": "user", "content": user_input})
|
||||||
|
else:
|
||||||
|
# Simple text message
|
||||||
|
messages.append({"role": "user", "content": user_input})
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|||||||
@@ -59,6 +59,9 @@ class OllamaProvider(AIProvider):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if server is accessible
|
True if server is accessible
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ConnectionError: If server is not accessible
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
response = requests.get(f"{self.base_url}/api/tags", timeout=2)
|
response = requests.get(f"{self.base_url}/api/tags", timeout=2)
|
||||||
@@ -66,11 +69,59 @@ class OllamaProvider(AIProvider):
|
|||||||
logger.info(f"Ollama server accessible at {self.base_url}")
|
logger.info(f"Ollama server accessible at {self.base_url}")
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Ollama server returned status {response.status_code}")
|
error_msg = f"Ollama server returned status {response.status_code} at {self.base_url}"
|
||||||
return False
|
logger.error(error_msg)
|
||||||
|
raise ConnectionError(error_msg)
|
||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
logger.warning(f"Ollama server not accessible at {self.base_url}: {e}")
|
error_msg = f"Cannot connect to Ollama server at {self.base_url}. Is Ollama running? Error: {e}"
|
||||||
return False
|
logger.error(error_msg)
|
||||||
|
raise ConnectionError(error_msg) from e
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Get provider name."""
|
||||||
|
return "Ollama"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def capabilities(self) -> ProviderCapabilities:
|
||||||
|
"""Get provider capabilities."""
|
||||||
|
return ProviderCapabilities(
|
||||||
|
streaming=True,
|
||||||
|
tools=False, # Tool support varies by model
|
||||||
|
images=False, # Image support varies by model
|
||||||
|
online=True, # Web search via DuckDuckGo/Google
|
||||||
|
max_context=8192, # Varies by model
|
||||||
|
)
|
||||||
|
|
||||||
|
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
|
||||||
|
"""
|
||||||
|
List models from local Ollama installation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filter_text_only: Ignored for Ollama
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of available models
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = requests.get(f"{self.base_url}/api/tags", timeout=5)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
models = []
|
||||||
|
for model_data in data.get("models", []):
|
||||||
|
models.append(self._parse_model(model_data))
|
||||||
|
|
||||||
|
if models:
|
||||||
|
logger.info(f"Found {len(models)} Ollama models")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Ollama server at {self.base_url} has no models installed. Install models with: ollama pull <model_name>")
|
||||||
|
return models
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
logger.error(f"Failed to connect to Ollama server at {self.base_url}: {e}")
|
||||||
|
logger.error(f"Make sure Ollama is running. Start it with: ollama serve")
|
||||||
|
return []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
|
|||||||
149
oai/tui/app.py
149
oai/tui/app.py
@@ -287,14 +287,16 @@ class oAIChatApp(App):
|
|||||||
if not user_input:
|
if not user_input:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Clear input field
|
# Clear input field and refocus for next message
|
||||||
input_bar = self.query_one(InputBar)
|
input_bar = self.query_one(InputBar)
|
||||||
chat_input = input_bar.get_input()
|
chat_input = input_bar.get_input()
|
||||||
chat_input.clear()
|
chat_input.clear()
|
||||||
|
|
||||||
# Process the input
|
|
||||||
await self._process_submitted_input(user_input)
|
await self._process_submitted_input(user_input)
|
||||||
|
|
||||||
|
# Keep input focused so user can type while AI responds
|
||||||
|
chat_input.focus()
|
||||||
|
|
||||||
async def _process_submitted_input(self, user_input: str) -> None:
|
async def _process_submitted_input(self, user_input: str) -> None:
|
||||||
"""Process submitted input (command or message).
|
"""Process submitted input (command or message).
|
||||||
|
|
||||||
@@ -313,16 +315,25 @@ class oAIChatApp(App):
|
|||||||
self.history_index = -1
|
self.history_index = -1
|
||||||
self._save_input_to_history(user_input)
|
self._save_input_to_history(user_input)
|
||||||
|
|
||||||
# Always show what the user typed
|
# Show user message
|
||||||
chat_display = self.query_one(ChatDisplay)
|
chat_display = self.query_one(ChatDisplay)
|
||||||
user_widget = UserMessageWidget(user_input)
|
user_widget = UserMessageWidget(user_input)
|
||||||
await chat_display.add_message(user_widget)
|
await chat_display.add_message(user_widget)
|
||||||
|
|
||||||
# Check if it's a command
|
# Check if it's a command or message
|
||||||
if user_input.startswith("/"):
|
if user_input.startswith("/"):
|
||||||
await self.handle_command(user_input)
|
# Defer command processing until after the UI renders
|
||||||
|
self.call_after_refresh(lambda: self.handle_command(user_input))
|
||||||
else:
|
else:
|
||||||
await self.handle_message(user_input)
|
# Defer assistant widget and AI call until after UI renders
|
||||||
|
async def setup_and_call_ai():
|
||||||
|
model_name = self.session.selected_model.get("name", "Assistant") if self.session.selected_model else "Assistant"
|
||||||
|
assistant_widget = AssistantMessageWidget(model_name, chat_display=chat_display)
|
||||||
|
await chat_display.add_message(assistant_widget)
|
||||||
|
assistant_widget.set_content("_Thinking... (Press Esc to stop)_")
|
||||||
|
await self.handle_message(user_input, assistant_widget)
|
||||||
|
|
||||||
|
self.call_after_refresh(setup_and_call_ai)
|
||||||
|
|
||||||
async def handle_command(self, command_text: str) -> None:
|
async def handle_command(self, command_text: str) -> None:
|
||||||
"""Handle a slash command."""
|
"""Handle a slash command."""
|
||||||
@@ -504,6 +515,11 @@ class oAIChatApp(App):
|
|||||||
chat_display = self.query_one(ChatDisplay)
|
chat_display = self.query_one(ChatDisplay)
|
||||||
error_widget = SystemMessageWidget(f"❌ {result.message}")
|
error_widget = SystemMessageWidget(f"❌ {result.message}")
|
||||||
await chat_display.add_message(error_widget)
|
await chat_display.add_message(error_widget)
|
||||||
|
elif result.status == CommandStatus.WARNING:
|
||||||
|
# Display warning in chat
|
||||||
|
chat_display = self.query_one(ChatDisplay)
|
||||||
|
warning_widget = SystemMessageWidget(f"⚠️ {result.message}")
|
||||||
|
await chat_display.add_message(warning_widget)
|
||||||
elif result.message:
|
elif result.message:
|
||||||
# Display success message
|
# Display success message
|
||||||
chat_display = self.query_one(ChatDisplay)
|
chat_display = self.query_one(ChatDisplay)
|
||||||
@@ -541,75 +557,78 @@ class oAIChatApp(App):
|
|||||||
# Update online mode indicator
|
# Update online mode indicator
|
||||||
input_bar.update_online_mode(self.session.online_enabled)
|
input_bar.update_online_mode(self.session.online_enabled)
|
||||||
|
|
||||||
async def handle_message(self, user_input: str) -> None:
|
async def handle_message(self, user_input: str, assistant_widget: AssistantMessageWidget = None) -> None:
|
||||||
"""Handle a chat message (user message already added by caller)."""
|
"""Handle a chat message (user message and assistant widget already added by caller)."""
|
||||||
chat_display = self.query_one(ChatDisplay)
|
chat_display = self.query_one(ChatDisplay)
|
||||||
|
|
||||||
# Create assistant message widget with loading indicator
|
# If no assistant widget provided (legacy), create one
|
||||||
model_name = self.session.selected_model.get("name", "Assistant") if self.session.selected_model else "Assistant"
|
if assistant_widget is None:
|
||||||
assistant_widget = AssistantMessageWidget(model_name, chat_display=chat_display)
|
model_name = self.session.selected_model.get("name", "Assistant") if self.session.selected_model else "Assistant"
|
||||||
await chat_display.add_message(assistant_widget)
|
assistant_widget = AssistantMessageWidget(model_name, chat_display=chat_display)
|
||||||
|
await chat_display.add_message(assistant_widget)
|
||||||
|
assistant_widget.set_content("_Thinking... (Press Esc to stop)_")
|
||||||
|
|
||||||
# Show loading indicator with cancellation hint
|
|
||||||
assistant_widget.set_content("_Thinking... (Press Esc to stop)_")
|
|
||||||
|
|
||||||
# Set generation flags
|
|
||||||
self._is_generating = True
|
self._is_generating = True
|
||||||
self._cancel_generation = False
|
self._cancel_generation = False
|
||||||
|
|
||||||
try:
|
# Run streaming in background to keep UI responsive
|
||||||
# Stream response
|
async def stream_task():
|
||||||
response_iterator = self.session.send_message_async(
|
try:
|
||||||
user_input,
|
# Stream response
|
||||||
stream=self.settings.stream_enabled,
|
response_iterator = self.session.send_message_async(
|
||||||
)
|
user_input,
|
||||||
|
stream=self.settings.stream_enabled,
|
||||||
# Stream and collect response with cancellation support
|
|
||||||
full_text, usage = await assistant_widget.stream_response(
|
|
||||||
response_iterator,
|
|
||||||
cancel_check=lambda: self._cancel_generation
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add to history if we got a response
|
|
||||||
if full_text:
|
|
||||||
# Extract cost from usage or calculate from pricing
|
|
||||||
cost = 0.0
|
|
||||||
if usage and hasattr(usage, 'total_cost_usd') and usage.total_cost_usd:
|
|
||||||
cost = usage.total_cost_usd
|
|
||||||
self.notify(f"Cost from API: ${cost:.6f}", severity="information")
|
|
||||||
elif usage and self.session.selected_model:
|
|
||||||
# Calculate cost from model pricing
|
|
||||||
pricing = self.session.selected_model.get("pricing", {})
|
|
||||||
prompt_cost = float(pricing.get("prompt", 0))
|
|
||||||
completion_cost = float(pricing.get("completion", 0))
|
|
||||||
|
|
||||||
# Prices are per token, convert to dollars
|
|
||||||
prompt_total = usage.prompt_tokens * prompt_cost
|
|
||||||
completion_total = usage.completion_tokens * completion_cost
|
|
||||||
cost = prompt_total + completion_total
|
|
||||||
|
|
||||||
if cost > 0:
|
|
||||||
self.notify(f"Cost calculated: ${cost:.6f}", severity="information")
|
|
||||||
|
|
||||||
self.session.add_to_history(
|
|
||||||
prompt=user_input,
|
|
||||||
response=full_text,
|
|
||||||
usage=usage,
|
|
||||||
cost=cost,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update footer
|
# Stream and collect response with cancellation support
|
||||||
self._update_footer()
|
full_text, usage = await assistant_widget.stream_response(
|
||||||
|
response_iterator,
|
||||||
|
cancel_check=lambda: self._cancel_generation
|
||||||
|
)
|
||||||
|
|
||||||
# Check if generation was cancelled
|
# Add to history if we got a response
|
||||||
if self._cancel_generation and full_text:
|
if full_text:
|
||||||
assistant_widget.set_content(full_text + "\n\n_[Generation stopped by user]_")
|
# Extract cost from usage or calculate from pricing
|
||||||
|
cost = 0.0
|
||||||
|
if usage and hasattr(usage, 'total_cost_usd') and usage.total_cost_usd:
|
||||||
|
cost = usage.total_cost_usd
|
||||||
|
self.notify(f"Cost from API: ${cost:.6f}", severity="information")
|
||||||
|
elif usage and self.session.selected_model:
|
||||||
|
# Calculate cost from model pricing
|
||||||
|
pricing = self.session.selected_model.get("pricing", {})
|
||||||
|
prompt_cost = float(pricing.get("prompt", 0))
|
||||||
|
completion_cost = float(pricing.get("completion", 0))
|
||||||
|
|
||||||
except Exception as e:
|
# Prices are per token, convert to dollars
|
||||||
assistant_widget.set_content(f"❌ Error: {str(e)}")
|
prompt_total = usage.prompt_tokens * prompt_cost
|
||||||
finally:
|
completion_total = usage.completion_tokens * completion_cost
|
||||||
self._is_generating = False
|
cost = prompt_total + completion_total
|
||||||
self._cancel_generation = False
|
|
||||||
|
if cost > 0:
|
||||||
|
self.notify(f"Cost calculated: ${cost:.6f}", severity="information")
|
||||||
|
|
||||||
|
self.session.add_to_history(
|
||||||
|
prompt=user_input,
|
||||||
|
response=full_text,
|
||||||
|
usage=usage,
|
||||||
|
cost=cost,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update footer
|
||||||
|
self._update_footer()
|
||||||
|
|
||||||
|
# Check if generation was cancelled
|
||||||
|
if self._cancel_generation and full_text:
|
||||||
|
assistant_widget.set_content(full_text + "\n\n_[Generation stopped by user]_")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
assistant_widget.set_content(f"❌ Error: {str(e)}")
|
||||||
|
finally:
|
||||||
|
self._is_generating = False
|
||||||
|
self._cancel_generation = False
|
||||||
|
|
||||||
|
# Create background task - don't await it!
|
||||||
|
asyncio.create_task(stream_task())
|
||||||
|
|
||||||
def _update_footer(self) -> None:
|
def _update_footer(self) -> None:
|
||||||
"""Update footer statistics."""
|
"""Update footer statistics."""
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ class ChatDisplay(ScrollableContainer):
|
|||||||
async def add_message(self, widget: Static) -> None:
|
async def add_message(self, widget: Static) -> None:
|
||||||
"""Add a message widget to the display."""
|
"""Add a message widget to the display."""
|
||||||
await self.mount(widget)
|
await self.mount(widget)
|
||||||
|
self.refresh(layout=True)
|
||||||
self.scroll_end(animate=False)
|
self.scroll_end(animate=False)
|
||||||
|
|
||||||
def clear_messages(self) -> None:
|
def clear_messages(self) -> None:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Message widgets for oAI TUI."""
|
"""Message widgets for oAI TUI."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from typing import Any, AsyncIterator, Tuple
|
from typing import Any, AsyncIterator, Tuple
|
||||||
|
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
@@ -101,6 +102,9 @@ class AssistantMessageWidget(Static):
|
|||||||
if hasattr(chunk, "usage") and chunk.usage:
|
if hasattr(chunk, "usage") and chunk.usage:
|
||||||
usage = chunk.usage
|
usage = chunk.usage
|
||||||
|
|
||||||
|
# Yield control to event loop so UI stays responsive
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
return self.full_text, usage
|
return self.full_text, usage
|
||||||
|
|
||||||
def set_content(self, content: str) -> None:
|
def set_content(self, content: str) -> None:
|
||||||
|
|||||||
@@ -278,11 +278,16 @@ def prepare_file_attachment(
|
|||||||
file_data = f.read()
|
file_data = f.read()
|
||||||
|
|
||||||
if category == "image":
|
if category == "image":
|
||||||
# Check if model supports images
|
# Check if model supports images - try multiple possible locations
|
||||||
input_modalities = model_capabilities.get("architecture", {}).get("input_modalities", [])
|
input_modalities = (
|
||||||
if "image" not in input_modalities:
|
model_capabilities.get("input_modalities", []) or
|
||||||
logger.warning(f"Model does not support images")
|
model_capabilities.get("architecture", {}).get("input_modalities", [])
|
||||||
return None
|
)
|
||||||
|
|
||||||
|
# If no input_modalities found or image not in list, try to attach anyway
|
||||||
|
# Some models support images but don't advertise it properly
|
||||||
|
if input_modalities and "image" not in input_modalities:
|
||||||
|
logger.warning(f"Model may not support images, attempting anyway...")
|
||||||
|
|
||||||
b64_data = base64.b64encode(file_data).decode("utf-8")
|
b64_data = base64.b64encode(file_data).decode("utf-8")
|
||||||
return {
|
return {
|
||||||
@@ -291,12 +296,16 @@ def prepare_file_attachment(
|
|||||||
}
|
}
|
||||||
|
|
||||||
elif category == "pdf":
|
elif category == "pdf":
|
||||||
# Check if model supports PDFs
|
# Check if model supports PDFs - try multiple possible locations
|
||||||
input_modalities = model_capabilities.get("architecture", {}).get("input_modalities", [])
|
input_modalities = (
|
||||||
|
model_capabilities.get("input_modalities", []) or
|
||||||
|
model_capabilities.get("architecture", {}).get("input_modalities", [])
|
||||||
|
)
|
||||||
|
|
||||||
supports_pdf = any(mod in input_modalities for mod in ["document", "pdf", "file"])
|
supports_pdf = any(mod in input_modalities for mod in ["document", "pdf", "file"])
|
||||||
if not supports_pdf:
|
|
||||||
logger.warning(f"Model does not support PDFs")
|
if input_modalities and not supports_pdf:
|
||||||
return None
|
logger.warning(f"Model may not support PDFs, attempting anyway...")
|
||||||
|
|
||||||
b64_data = base64.b64encode(file_data).decode("utf-8")
|
b64_data = base64.b64encode(file_data).decode("utf-8")
|
||||||
return {
|
return {
|
||||||
@@ -321,3 +330,49 @@ def prepare_file_attachment(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error preparing file attachment {path}: {e}")
|
logger.error(f"Error preparing file attachment {path}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def parse_file_attachments(user_input: str) -> Tuple[str, list[Path]]:
|
||||||
|
"""
|
||||||
|
Parse user input for @<file> or @file syntax and extract file paths.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_input: User's message that may contain @<file> or @file references
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (cleaned_text, list_of_file_paths)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> parse_file_attachments("Look at @<image.png> and @/path/to/doc.pdf")
|
||||||
|
("Look at and ", [Path("image.png"), Path("/path/to/doc.pdf")])
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
# Pattern to match both @<filepath> and @filepath
|
||||||
|
# Matches @<...> or @followed by path-like strings
|
||||||
|
pattern = r'@<([^>]+)>|@([/~.][\S]+|[a-zA-Z]:[/\\][\S]+)'
|
||||||
|
|
||||||
|
# Find all file references
|
||||||
|
file_paths = []
|
||||||
|
matches_to_remove = []
|
||||||
|
|
||||||
|
for match in re.finditer(pattern, user_input):
|
||||||
|
# Group 1 is for @<filepath>, Group 2 is for @filepath
|
||||||
|
file_path_str = (match.group(1) or match.group(2)).strip()
|
||||||
|
file_path = Path(file_path_str).expanduser().resolve()
|
||||||
|
|
||||||
|
if file_path.exists():
|
||||||
|
file_paths.append(file_path)
|
||||||
|
logger.info(f"Found file attachment: {file_path}")
|
||||||
|
matches_to_remove.append(match.group(0))
|
||||||
|
else:
|
||||||
|
logger.warning(f"File not found: {file_path_str}")
|
||||||
|
matches_to_remove.append(match.group(0))
|
||||||
|
|
||||||
|
# Remove @<file> and @file references from the text
|
||||||
|
cleaned_text = user_input
|
||||||
|
for match_str in matches_to_remove:
|
||||||
|
cleaned_text = cleaned_text.replace(match_str, '', 1)
|
||||||
|
|
||||||
|
return cleaned_text, file_paths
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "oai"
|
name = "oai"
|
||||||
version = "3.0.0-b3" # MUST match oai/__init__.py __version__
|
version = "3.0.0-b4" # MUST match oai/__init__.py __version__
|
||||||
description = "Open AI Chat Client - Multi-provider terminal chat with MCP support"
|
description = "Open AI Chat Client - Multi-provider terminal chat with MCP support"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = {text = "MIT"}
|
license = {text = "MIT"}
|
||||||
|
|||||||
Reference in New Issue
Block a user