5 Commits

| SHA1 | Message | Date |
| --- | --- | --- |
| fd99d5c778 | Bug fixes. Added missing functionality. ++ | 2026-02-09 11:11:50 +01:00 |
| b95568d1ba | Fixed some tool use errors | 2026-02-06 13:29:14 +01:00 |
| f369488d5b | updated README | 2026-02-06 10:18:38 +01:00 |
| 4788470dc5 | updated README | 2026-02-06 09:57:24 +01:00 |
| 603d42b7ff | Bug fixes for v3.0.0 | 2026-02-06 09:48:37 +01:00 |
20 changed files with 606 additions and 188 deletions

.gitignore vendored

@@ -46,3 +46,5 @@ requirements.txt
 system_prompt.txt
 CLAUDE*
 SESSION*_COMPLETE.md
+oai/FILE_ATTACHMENTS_FIX.md
+oai/OLLAMA_ERROR_HANDLING.md


@@ -2,6 +2,11 @@
 A powerful, modern **Textual TUI** chat client with **multi-provider support** (OpenRouter, Anthropic, OpenAI, Ollama) and **MCP (Model Context Protocol)** integration, enabling AI to access local files and query SQLite databases.
+>[!WARNING]
+> v3.0.0-b4 is beta software. While I strive for stability, beta versions may contain bugs, incomplete features, or unexpected behavior. I actively work on improvements and appreciate your feedback.
+>
+>Beta releases are ideal for testing new features and providing feedback. For production use or maximum stability, consider using the latest stable release.
 ## Features
 ### Core Features
@@ -141,6 +146,33 @@ On first run, you'll be prompted for your API key. Configure additional provider
 Ctrl+Q       # Quit
 ```
+### File Attachments
+Attach files to your messages using `@<file>` or `@file` syntax:
+```bash
+# Single file
+Describe this image @<photo.jpg>
+Analyze @~/Documents/report.pdf
+# Multiple files
+Compare @image1.png and @image2.png
+# With paths
+Review this code @./src/main.py
+What's in @/Users/username/Downloads/screenshot.png
+```
+**Supported file types:**
+- **Images**: `.png`, `.jpg`, `.jpeg`, `.gif`, `.webp`, `.bmp` (for vision models)
+- **PDFs**: `.pdf` (for document-capable models)
+- **Code/Text**: `.py`, `.js`, `.md`, `.txt`, `.json`, etc. (all models)
+**Notes:**
+- Files are automatically base64-encoded for images/PDFs
+- Maximum file size: 10MB per file
+- Works with vision-capable models (Claude, GPT-4o, Gemini, etc.)
 ## Web Search
 oAI provides universal web search capabilities for all AI providers with three options:
@@ -341,7 +373,6 @@ oai/
 │   │   └── styles.tcss    # TUI styling
 │   └── utils/             # Logging, export, etc.
 ├── pyproject.toml         # Package configuration
-├── build.sh               # Binary build script
 └── README.md
 ```


@@ -9,7 +9,7 @@ Author: Rune
 License: MIT
 """
-__version__ = "3.0.0-b3"
+__version__ = "3.0.0-b4"
 __author__ = "Rune Olsen"
 __license__ = "MIT"


@@ -125,6 +125,23 @@ def _launch_tui(
             provider_api_keys=provider_api_keys,
             ollama_base_url=settings.ollama_base_url,
         )
+    except ConnectionError as e:
+        # Special handling for Ollama connection errors
+        if selected_provider == "ollama":
+            typer.echo(f"Error: Cannot connect to Ollama server", err=True)
+            typer.echo(f"", err=True)
+            typer.echo(f"Details: {e}", err=True)
+            typer.echo(f"", err=True)
+            typer.echo(f"Please make sure Ollama is running:", err=True)
+            typer.echo(f"  ollama serve", err=True)
+            typer.echo(f"", err=True)
+            typer.echo(f"Or configure a different Ollama URL:", err=True)
+            typer.echo(f"  oai config ollama_base_url <url>", err=True)
+            typer.echo(f"", err=True)
+            typer.echo(f"Current Ollama URL: {settings.ollama_base_url}", err=True)
+        else:
+            typer.echo(f"Error: Connection failed: {e}", err=True)
+        raise typer.Exit(1)
     except Exception as e:
         typer.echo(f"Error: Failed to initialize client: {e}", err=True)
         raise typer.Exit(1)


@@ -139,7 +139,6 @@ class HelpCommand(Command):
("/retry", "Resend the last prompt.", "/retry"), ("/retry", "Resend the last prompt.", "/retry"),
("/memory", "Toggle conversation memory.", "/memory on"), ("/memory", "Toggle conversation memory.", "/memory on"),
("/online", "Toggle online mode (web search).", "/online on"), ("/online", "Toggle online mode (web search).", "/online on"),
("/paste", "Paste from clipboard with optional prompt.", "/paste Explain"),
]), ]),
("[bold cyan]━━━ NAVIGATION ━━━[/]", [ ("[bold cyan]━━━ NAVIGATION ━━━[/]", [
("/prev", "View previous response in history.", "/prev"), ("/prev", "View previous response in history.", "/prev"),
@@ -150,6 +149,7 @@ class HelpCommand(Command):
("/model", "Select AI model.", "/model gpt"), ("/model", "Select AI model.", "/model gpt"),
("/info", "Show model information.", "/info"), ("/info", "Show model information.", "/info"),
("/config", "View or change settings.", "/config stream on"), ("/config", "View or change settings.", "/config stream on"),
("/config log", "Set log level.", "/config log debug"),
("/maxtoken", "Set session token limit.", "/maxtoken 2000"), ("/maxtoken", "Set session token limit.", "/maxtoken 2000"),
("/system", "Set system prompt.", "/system You are an expert"), ("/system", "Set system prompt.", "/system You are an expert"),
]), ]),
@@ -1022,13 +1022,20 @@ class ConfigCommand(Command):
             else:
                 pass
-        elif setting == "loglevel":
+        elif setting in ["loglevel", "log"]:
             valid_levels = ["debug", "info", "warning", "error", "critical"]
-            if value and value.lower() in valid_levels:
-                settings.set_log_level(value.lower())
-                print_success(f"Log level set to: {value.lower()}")
+            if value:
+                if value.lower() in valid_levels:
+                    settings.set_log_level(value.lower())
+                    message = f"✓ Log level set to: {value.lower()}"
+                    return CommandResult.success(message=message)
+                else:
+                    message = f"Invalid log level. Valid levels: {', '.join(valid_levels)}"
+                    return CommandResult.error(message)
             else:
-                print_info(f"Valid levels: {', '.join(valid_levels)}")
+                # Show current log level
+                message = f"Current log level: {settings.log_level}\nValid levels: {', '.join(valid_levels)}"
+                return CommandResult.success(message=message)
         else:
             pass
@@ -1582,10 +1589,30 @@ class ProviderCommand(Command):
             else:
                 # No models available, clear selection
                 context.session.selected_model = None
-                message = f"Switched to {provider_name} provider (no models available)"
-                return CommandResult.success(message=message)
+                if provider_name == "ollama":
+                    # Special message for Ollama with helpful instructions
+                    message = f"⚠️ Switched to Ollama, but no models are installed.\n\n"
+                    message += "To install models, run:\n"
+                    message += "  ollama pull llama3.2\n"
+                    message += "  ollama pull mistral\n\n"
+                    message += "See available models at: https://ollama.com/library"
+                    return CommandResult.warning(message=message)
+                else:
+                    message = f"Switched to {provider_name} provider (no models available)"
+                    return CommandResult.warning(message=message)

             return CommandResult.success(message=f"Switched to {provider_name} provider")
+        except ConnectionError as e:
+            # Connection error - likely Ollama server not running
+            if provider_name == "ollama":
+                error_msg = f"❌ Cannot connect to Ollama server.\n\n"
+                error_msg += f"Error: {str(e)}\n\n"
+                error_msg += "Please make sure Ollama is running:\n"
+                error_msg += "  ollama serve\n\n"
+                error_msg += f"Or check your Ollama URL in config (currently: {context.settings.ollama_base_url})"
+                return CommandResult.error(error_msg)
+            else:
+                return CommandResult.error(f"Connection error: {e}")
         except Exception as e:
             return CommandResult.error(f"Failed to switch provider: {e}")
@@ -1600,32 +1627,15 @@ class PasteCommand(Command):
     @property
     def help(self) -> CommandHelp:
         return CommandHelp(
-            description="Paste from clipboard and send to AI.",
-            usage="/paste [prompt]",
+            description="[Disabled] Use Cmd+V (macOS) or Ctrl+V (Linux/Windows) to paste directly.",
+            usage="/paste",
+            notes="This command is disabled. Use keyboard shortcuts to paste instead.",
         )

     def execute(self, args: str, context: CommandContext) -> CommandResult:
-        try:
-            import pyperclip
-            content = pyperclip.paste()
-        except ImportError:
-            message = "pyperclip not installed"
-            return CommandResult.error(message)
-        except Exception as e:
-            message = f"Failed to access clipboard: {e}"
-            return CommandResult.error(str(e))
-        if not content:
-            message = "Clipboard is empty"
-            return CommandResult.error(message)
-        # Build the prompt
-        if args:
-            full_prompt = f"{args}\n\n```\n{content}\n```"
-        else:
-            full_prompt = content
-        return CommandResult.success(data={"paste_prompt": full_prompt})
+        # Disabled - use Cmd+V (macOS) or Ctrl+V (Linux/Windows) instead
+        message = "💡 Tip: Use Cmd+V (macOS) or Ctrl+V (Linux/Windows) to paste directly"
+        return CommandResult.success(message=message)


 class ModelCommand(Command):
     """Select AI model."""


@@ -23,6 +23,7 @@ class CommandStatus(str, Enum):
"""Status of command execution.""" """Status of command execution."""
SUCCESS = "success" SUCCESS = "success"
WARNING = "warning"
ERROR = "error" ERROR = "error"
CONTINUE = "continue" # Continue to next handler CONTINUE = "continue" # Continue to next handler
EXIT = "exit" # Exit the application EXIT = "exit" # Exit the application
@@ -50,6 +51,11 @@ class CommandResult:
"""Create a success result.""" """Create a success result."""
return cls(status=CommandStatus.SUCCESS, message=message, data=data) return cls(status=CommandStatus.SUCCESS, message=message, data=data)
@classmethod
def warning(cls, message: str, data: Any = None) -> "CommandResult":
"""Create a warning result."""
return cls(status=CommandStatus.WARNING, message=message, data=data)
@classmethod @classmethod
def error(cls, message: str) -> "CommandResult": def error(cls, message: str) -> "CommandResult":
"""Create an error result.""" """Create an error result."""


@@ -221,18 +221,24 @@ class AIClient:
f"messages={len(chat_messages)}, stream={stream}" f"messages={len(chat_messages)}, stream={stream}"
) )
return self.provider.chat( # Build provider chat parameters
model=model_id, chat_params = {
messages=chat_messages, "model": model_id,
stream=stream, "messages": chat_messages,
max_tokens=max_tokens, "stream": stream,
temperature=temperature, "max_tokens": max_tokens,
tools=tools, "temperature": temperature,
tool_choice=tool_choice, "tools": tools,
transforms=transforms, "tool_choice": tool_choice,
enable_web_search=enable_web_search, "transforms": transforms,
web_search_config=web_search_config or {}, }
)
# Only pass web search params to Anthropic provider
if self.provider_name == "anthropic":
chat_params["enable_web_search"] = enable_web_search
chat_params["web_search_config"] = web_search_config or {}
return self.provider.chat(**chat_params)
def chat_with_tools( def chat_with_tools(
self, self,
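The idea behind this hunk is to collect keyword arguments in a dict and only add provider-specific keys when the active provider accepts them, so other providers never receive unexpected kwargs. A minimal standalone sketch of that pattern; the function and parameter names below are illustrative, not the project's API:

```python
from typing import Any, Dict, List


def build_chat_params(
    provider_name: str,
    model: str,
    messages: List[dict],
    stream: bool = False,
    enable_web_search: bool = False,
) -> Dict[str, Any]:
    """Assemble kwargs for a provider call, gating provider-specific options."""
    params: Dict[str, Any] = {
        "model": model,
        "messages": messages,
        "stream": stream,
    }
    # In this sketch only an Anthropic-style provider understands the web-search
    # flag; passing an unknown kwarg to another provider would raise TypeError.
    if provider_name == "anthropic":
        params["enable_web_search"] = enable_web_search
    return params


print(build_chat_params("openai", "gpt-4o", [{"role": "user", "content": "hi"}]))
```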


@@ -24,6 +24,7 @@ from oai.mcp.manager import MCPManager
 from oai.providers.base import ChatResponse, StreamChunk, UsageStats
 from oai.utils.logging import get_logger
 from oai.utils.web_search import perform_web_search, format_search_results
+from oai.utils.files import parse_file_attachments, prepare_file_attachment

 logger = get_logger()
@@ -218,8 +219,40 @@ class ChatSession:
messages.append({"role": "user", "content": entry.prompt}) messages.append({"role": "user", "content": entry.prompt})
messages.append({"role": "assistant", "content": entry.response}) messages.append({"role": "assistant", "content": entry.response})
# Add current message # Parse file attachments from user input
messages.append({"role": "user", "content": user_input}) cleaned_text, file_paths = parse_file_attachments(user_input)
# Build content for current message
if file_paths:
# Multi-modal message with attachments
content_parts = []
# Add text part if there's any text
if cleaned_text.strip():
content_parts.append({
"type": "text",
"text": cleaned_text.strip()
})
# Add file attachments
for file_path in file_paths:
attachment = prepare_file_attachment(
file_path,
self.selected_model or {}
)
if attachment:
content_parts.append(attachment)
else:
logger.warning(f"Could not attach file: {file_path}")
# If we have content parts, use them; otherwise fall back to text
if content_parts:
messages.append({"role": "user", "content": content_parts})
else:
messages.append({"role": "user", "content": user_input})
else:
# Simple text message
messages.append({"role": "user", "content": user_input})
return messages return messages
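For context, a standalone sketch of the kind of multi-part user message this attachment path produces: one text part plus one base64-encoded image part. The exact part schema comes from `prepare_file_attachment`; the field names below follow the common chat-completions convention and are an assumption, not a copy of the project's output.

```python
import base64
from pathlib import Path
from typing import List


def build_user_content(text: str, image_path: Path) -> List[dict]:
    """Return a multi-part content list: a text part and a data-URL image part."""
    b64 = base64.b64encode(image_path.read_bytes()).decode("utf-8")
    return [
        {"type": "text", "text": text},
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{b64}"},
        },
    ]


# Illustrative usage (requires a real file on disk):
# message = {"role": "user", "content": build_user_content("Describe this", Path("photo.png"))}
```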
@@ -517,15 +550,21 @@ class ChatSession:
         Returns:
             Tuple of (full_text, usage)
         """
-        response = self.client.chat(
-            messages=messages,
-            model=model_id,
-            stream=True,
-            max_tokens=max_tokens,
-            transforms=transforms,
-            enable_web_search=enable_web_search,
-            web_search_config=web_search_config or {},
-        )
+        # Build chat parameters
+        chat_params = {
+            "messages": messages,
+            "model": model_id,
+            "stream": True,
+            "max_tokens": max_tokens,
+            "transforms": transforms,
+        }
+
+        # Only pass web search params to Anthropic provider
+        if self.client.provider_name == "anthropic":
+            chat_params["enable_web_search"] = enable_web_search
+            chat_params["web_search_config"] = web_search_config or {}
+
+        response = self.client.chat(**chat_params)

         if isinstance(response, ChatResponse):
             return response.content or "", response.usage
@@ -647,8 +686,9 @@ class ChatSession:
             ):
                 yield chunk
         else:
-            # Non-streaming request
-            response = self.client.chat(
+            # Non-streaming request - run in thread to avoid blocking event loop
+            response = await asyncio.to_thread(
+                self.client.chat,
                 messages=messages,
                 model=model_id,
                 stream=False,
@@ -691,7 +731,9 @@ class ChatSession:
         api_messages = list(messages)

         while loop_count < max_loops:
-            response = self.client.chat(
+            # Run in thread to avoid blocking event loop
+            response = await asyncio.to_thread(
+                self.client.chat,
                 messages=api_messages,
                 model=model_id,
                 stream=False,
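Both non-streaming call sites now hand the blocking provider call to a worker thread so the Textual event loop keeps repainting while the request is in flight. A minimal sketch of the `asyncio.to_thread` pattern in isolation; the blocking function is a stand-in, not the project's client:

```python
import asyncio
import time


def blocking_chat_call(prompt: str) -> str:
    """Stand-in for a synchronous, blocking provider request."""
    time.sleep(1)  # simulate network latency
    return f"response to: {prompt}"


async def main() -> None:
    # The blocking call runs in a worker thread; the event loop stays free
    # to handle keypresses, timers, and screen refreshes in the meantime.
    reply = await asyncio.to_thread(blocking_chat_call, "hello")
    print(reply)


asyncio.run(main())
```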
@@ -837,33 +879,55 @@ class ChatSession:
         Yields:
             StreamChunk objects
         """
-        response = self.client.chat(
-            messages=messages,
-            model=model_id,
-            stream=True,
-            max_tokens=max_tokens,
-            transforms=transforms,
-            enable_web_search=enable_web_search,
-            web_search_config=web_search_config or {},
-        )
-
-        if isinstance(response, ChatResponse):
-            # Non-streaming response
-            chunk = StreamChunk(
-                id="",
-                delta_content=response.content,
-                usage=response.usage,
-                error=None,
-            )
-            yield chunk
-            return
-
-        # Stream the response
-        for chunk in response:
-            if chunk.error:
-                yield StreamChunk(id="", delta_content=None, usage=None, error=chunk.error)
-                break
-            yield chunk
+        # Build chat parameters
+        chat_params = {
+            "messages": messages,
+            "model": model_id,
+            "stream": True,
+            "max_tokens": max_tokens,
+            "transforms": transforms,
+        }
+
+        # Only pass web search params to Anthropic provider
+        if self.client.provider_name == "anthropic":
+            chat_params["enable_web_search"] = enable_web_search
+            chat_params["web_search_config"] = web_search_config or {}
+
+        # For streaming, call directly (generator yields control naturally)
+        # For non-streaming, we'll detect it and run in thread
+        if chat_params.get("stream", True):
+            # Streaming - call directly, iteration will yield control
+            response = self.client.chat(**chat_params)
+
+            if isinstance(response, ChatResponse):
+                # Provider returned non-streaming despite stream=True
+                chunk = StreamChunk(
+                    id="",
+                    delta_content=response.content,
+                    usage=response.usage,
+                    error=None,
+                )
+                yield chunk
+                return
+
+            # Stream the response - yield control between chunks
+            for chunk in response:
+                await asyncio.sleep(0)  # Yield control to event loop
+                if chunk.error:
+                    yield StreamChunk(id="", delta_content=None, usage=None, error=chunk.error)
+                    break
+                yield chunk
+        else:
+            # Non-streaming - run in thread to avoid blocking
+            response = await asyncio.to_thread(self.client.chat, **chat_params)
+
+            if isinstance(response, ChatResponse):
+                chunk = StreamChunk(
+                    id="",
+                    delta_content=response.content,
+                    usage=response.usage,
+                    error=None,
+                )
+                yield chunk

     # ========== END ASYNC METHODS ==========


@@ -647,6 +647,9 @@ class AnthropicProvider(AIProvider):
"description": model.description, "description": model.description,
"context_length": model.context_length, "context_length": model.context_length,
"pricing": model.pricing, "pricing": model.pricing,
"supported_parameters": model.supported_parameters,
"input_modalities": model.input_modalities,
"output_modalities": model.output_modalities,
} }
for model in models for model in models
] ]
@@ -669,5 +672,8 @@ class AnthropicProvider(AIProvider):
"description": model.description, "description": model.description,
"context_length": model.context_length, "context_length": model.context_length,
"pricing": model.pricing, "pricing": model.pricing,
"supported_parameters": model.supported_parameters,
"input_modalities": model.input_modalities,
"output_modalities": model.output_modalities,
} }
return None return None


@@ -59,6 +59,9 @@ class OllamaProvider(AIProvider):
         Returns:
             True if server is accessible

+        Raises:
+            ConnectionError: If server is not accessible
         """
         try:
             response = requests.get(f"{self.base_url}/api/tags", timeout=2)
@@ -66,11 +69,59 @@ class OllamaProvider(AIProvider):
logger.info(f"Ollama server accessible at {self.base_url}") logger.info(f"Ollama server accessible at {self.base_url}")
return True return True
else: else:
logger.warning(f"Ollama server returned status {response.status_code}") error_msg = f"Ollama server returned status {response.status_code} at {self.base_url}"
return False logger.error(error_msg)
raise ConnectionError(error_msg)
except requests.RequestException as e: except requests.RequestException as e:
logger.warning(f"Ollama server not accessible at {self.base_url}: {e}") error_msg = f"Cannot connect to Ollama server at {self.base_url}. Is Ollama running? Error: {e}"
return False logger.error(error_msg)
raise ConnectionError(error_msg) from e
@property
def name(self) -> str:
"""Get provider name."""
return "Ollama"
@property
def capabilities(self) -> ProviderCapabilities:
"""Get provider capabilities."""
return ProviderCapabilities(
streaming=True,
tools=False, # Tool support varies by model
images=False, # Image support varies by model
online=True, # Web search via DuckDuckGo/Google
max_context=8192, # Varies by model
)
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
"""
List models from local Ollama installation.
Args:
filter_text_only: Ignored for Ollama
Returns:
List of available models
"""
try:
response = requests.get(f"{self.base_url}/api/tags", timeout=5)
response.raise_for_status()
data = response.json()
models = []
for model_data in data.get("models", []):
models.append(self._parse_model(model_data))
if models:
logger.info(f"Found {len(models)} Ollama models")
else:
logger.warning(f"Ollama server at {self.base_url} has no models installed. Install models with: ollama pull <model_name>")
return models
except requests.RequestException as e:
logger.error(f"Failed to connect to Ollama server at {self.base_url}: {e}")
logger.error(f"Make sure Ollama is running. Start it with: ollama serve")
return []
@property @property
def name(self) -> str: def name(self) -> str:
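A condensed, standalone sketch of the connectivity probe this hunk introduces: hit the `/api/tags` endpoint and convert request failures into a `ConnectionError` with an actionable hint. The default URL is Ollama's standard local port; adjust if your server runs elsewhere.

```python
import requests


def check_ollama(base_url: str = "http://localhost:11434") -> bool:
    """Probe the Ollama tags endpoint; raise ConnectionError with a hint on failure."""
    try:
        resp = requests.get(f"{base_url}/api/tags", timeout=2)
    except requests.RequestException as exc:
        raise ConnectionError(
            f"Cannot connect to Ollama at {base_url}. Is `ollama serve` running?"
        ) from exc
    if resp.status_code != 200:
        raise ConnectionError(f"Ollama returned status {resp.status_code} at {base_url}")
    return True


# check_ollama()  # raises ConnectionError if the server is unreachable
```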
@@ -397,6 +448,9 @@ class OllamaProvider(AIProvider):
"description": model.description, "description": model.description,
"context_length": model.context_length, "context_length": model.context_length,
"pricing": model.pricing, "pricing": model.pricing,
"supported_parameters": model.supported_parameters,
"input_modalities": model.input_modalities,
"output_modalities": model.output_modalities,
} }
for model in models for model in models
] ]
@@ -419,5 +473,8 @@ class OllamaProvider(AIProvider):
"description": model.description, "description": model.description,
"context_length": model.context_length, "context_length": model.context_length,
"pricing": model.pricing, "pricing": model.pricing,
"supported_parameters": model.supported_parameters,
"input_modalities": model.input_modalities,
"output_modalities": model.output_modalities,
} }
return None return None


@@ -604,6 +604,9 @@ class OpenAIProvider(AIProvider):
"description": model.description, "description": model.description,
"context_length": model.context_length, "context_length": model.context_length,
"pricing": model.pricing, "pricing": model.pricing,
"supported_parameters": model.supported_parameters,
"input_modalities": model.input_modalities,
"output_modalities": model.output_modalities,
} }
for model in models for model in models
] ]
@@ -626,5 +629,8 @@ class OpenAIProvider(AIProvider):
"description": model.description, "description": model.description,
"context_length": model.context_length, "context_length": model.context_length,
"pricing": model.pricing, "pricing": model.pricing,
"supported_parameters": model.supported_parameters,
"input_modalities": model.input_modalities,
"output_modalities": model.output_modalities,
} }
return None return None


@@ -7,7 +7,7 @@ from typing import Optional
 import pyperclip
 from textual.app import App, ComposeResult
-from textual.widgets import Input
+from textual.widgets import TextArea

 from oai import __version__
 from oai.commands.registry import CommandStatus, registry
@@ -36,6 +36,7 @@ from oai.tui.widgets import (
     SystemMessageWidget,
     UserMessageWidget,
 )
+from oai.tui.widgets.input_bar import ChatTextArea
 from oai.tui.widgets.command_dropdown import CommandDropdown
@@ -66,6 +67,8 @@ class oAIChatApp(App):
         self.input_history: list[str] = []
         self.history_index: int = -1
         self._navigating_history: bool = False
+        self._cancel_generation: bool = False
+        self._is_generating: bool = False

     def compose(self) -> ComposeResult:
         """Compose the TUI layout."""
@@ -80,6 +83,9 @@ class oAIChatApp(App):
     def on_mount(self) -> None:
         """Handle app mount."""
+        # Load input history from file
+        self._load_input_history()
+
         # Focus the input
         input_bar = self.query_one(InputBar)
         chat_input = input_bar.get_input()
@@ -97,11 +103,47 @@ class oAIChatApp(App):
         if self.session.online_enabled:
             input_bar.update_online_mode(True)

+    def _load_input_history(self) -> None:
+        """Load input history from history.txt file."""
+        from oai.constants import HISTORY_FILE
+
+        try:
+            if HISTORY_FILE.exists():
+                with open(HISTORY_FILE, "r", encoding="utf-8") as f:
+                    # Load all non-empty lines and unescape newlines
+                    self.input_history = [
+                        line.strip().replace("\\n", "\n")
+                        for line in f
+                        if line.strip()
+                    ]
+        except Exception as e:
+            self.logger.error(f"Failed to load input history: {e}")
+
+    def _save_input_to_history(self, user_input: str) -> None:
+        """Append input to history.txt file."""
+        from oai.constants import HISTORY_FILE
+
+        try:
+            # Escape newlines so multiline inputs stay as one history entry
+            escaped_input = user_input.replace("\n", "\\n")
+            with open(HISTORY_FILE, "a", encoding="utf-8") as f:
+                f.write(f"{escaped_input}\n")
+        except Exception as e:
+            self.logger.error(f"Failed to save input to history: {e}")
+
     def on_key(self, event) -> None:
         """Handle global keyboard shortcuts."""
         # Debug: Show what key was pressed
         # self.notify(f"Key pressed: {event.key}", severity="information")

+        # Handle Escape to cancel generation
+        if event.key == "escape" and self._is_generating:
+            self._cancel_generation = True
+            self.notify("⏹️ Stopping generation...", severity="warning")
+            event.prevent_default()
+            event.stop()
+            return
+
         # Don't handle keys if a modal screen is open (let the modal handle them)
         if len(self.screen_stack) > 1:
             return
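The history persistence added here escapes newlines on write and unescapes them on read, so a multiline prompt survives as a single history entry navigable with the arrow keys. A self-contained sketch of that round trip; the file path below is illustrative, the app resolves it via `oai.constants.HISTORY_FILE`:

```python
from pathlib import Path

HISTORY_FILE = Path("history.txt")  # illustrative path for this sketch


def append_history(entry: str) -> None:
    """Store one prompt per line, with embedded newlines escaped."""
    with HISTORY_FILE.open("a", encoding="utf-8") as f:
        f.write(entry.replace("\n", "\\n") + "\n")


def load_history() -> list[str]:
    """Read entries back, restoring the original newlines."""
    if not HISTORY_FILE.exists():
        return []
    with HISTORY_FILE.open("r", encoding="utf-8") as f:
        return [line.strip().replace("\\n", "\n") for line in f if line.strip()]


append_history("first line\nsecond line")
print(load_history())  # ['first line\nsecond line']
```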
@@ -119,25 +161,29 @@ class oAIChatApp(App):
         if dropdown_visible:
             if event.key == "up":
                 event.prevent_default()
+                event.stop()
                 dropdown.move_selection_up()
                 return
             elif event.key == "down":
                 event.prevent_default()
+                event.stop()
                 dropdown.move_selection_down()
                 return
             elif event.key == "tab":
                 # Tab accepts the selected command and adds space for arguments
                 event.prevent_default()
+                event.stop()
                 selected = dropdown.get_selected_command()
                 if selected:
-                    chat_input.value = selected + " "
-                    chat_input.cursor_position = len(chat_input.value)
+                    chat_input.text = selected + " "
+                    chat_input.move_cursor_relative(rows=0, columns=len(selected) + 1)
                     dropdown.hide()
                 return
             elif event.key == "enter":
                 # Enter accepts the selected command
                 # If command needs more input, add space; otherwise submit
                 event.prevent_default()
+                event.stop()  # Stop propagation to prevent TextArea from processing
                 selected = dropdown.get_selected_command()
                 if selected:
                     # Commands that require additional arguments
@@ -155,13 +201,13 @@ class oAIChatApp(App):
                     if needs_input:
                         # Add space and wait for user to type more
-                        chat_input.value = selected + " "
-                        chat_input.cursor_position = len(chat_input.value)
+                        chat_input.text = selected + " "
+                        chat_input.move_cursor_relative(rows=0, columns=len(selected) + 1)
                         dropdown.hide()
                     else:
                         # Command is complete, submit it directly
                         dropdown.hide()
-                        chat_input.value = ""  # Clear immediately
+                        chat_input.clear()  # Clear immediately

                         # Process the command directly
                         async def submit_command():
                             await self._process_submitted_input(selected)
@@ -170,8 +216,14 @@ class oAIChatApp(App):
             elif event.key == "escape":
                 # Escape closes dropdown
                 event.prevent_default()
+                event.stop()
                 dropdown.hide()
                 return

+        # Escape to clear input (when dropdown not visible)
+        elif event.key == "escape":
+            event.prevent_default()
+            chat_input.clear()
+            return
+
         # Otherwise, arrow keys navigate history
         elif event.key == "up":
             event.prevent_default()
@@ -211,9 +263,9 @@ class oAIChatApp(App):
             event.prevent_default()
             self.action_copy_last_response()

-    def on_input_changed(self, event: Input.Changed) -> None:
-        """Handle input value changes to show/hide command dropdown."""
-        if event.input.id != "chat-input":
+    def on_text_area_changed(self, event: TextArea.Changed) -> None:
+        """Handle text area value changes to show/hide command dropdown."""
+        if event.text_area.id != "chat-input":
             return

         # Don't show dropdown when navigating history
@@ -221,7 +273,7 @@ class oAIChatApp(App):
             return

         dropdown = self.query_one(CommandDropdown)
-        value = event.value
+        value = event.text_area.text

         # Show dropdown if input starts with /
         if value.startswith("/") and not value.startswith("//"):
@@ -229,18 +281,22 @@ class oAIChatApp(App):
         else:
             dropdown.hide()

-    async def on_input_submitted(self, event: Input.Submitted) -> None:
-        """Handle input submission."""
+    async def on_chat_text_area_submit(self, event: ChatTextArea.Submit) -> None:
+        """Handle Enter key submission from ChatTextArea."""
         user_input = event.value.strip()
         if not user_input:
             return

-        # Clear input field immediately
-        event.input.value = ""
+        # Clear input field and refocus for next message
+        input_bar = self.query_one(InputBar)
+        chat_input = input_bar.get_input()
+        chat_input.clear()

+        # Process the input (async, will wait for AI response)
         await self._process_submitted_input(user_input)

+        # Keep input focused so user can type while AI responds
+        chat_input.focus()

     async def _process_submitted_input(self, user_input: str) -> None:
         """Process submitted input (command or message).
@@ -254,20 +310,30 @@ class oAIChatApp(App):
         dropdown = self.query_one(CommandDropdown)
         dropdown.hide()

-        # Add to history
+        # Add to in-memory history and save to file
         self.input_history.append(user_input)
         self.history_index = -1
+        self._save_input_to_history(user_input)

-        # Always show what the user typed
+        # Show user message
         chat_display = self.query_one(ChatDisplay)
         user_widget = UserMessageWidget(user_input)
         await chat_display.add_message(user_widget)

-        # Check if it's a command
+        # Check if it's a command or message
         if user_input.startswith("/"):
-            await self.handle_command(user_input)
+            # Defer command processing until after the UI renders
+            self.call_after_refresh(lambda: self.handle_command(user_input))
         else:
-            await self.handle_message(user_input)
+            # Defer assistant widget and AI call until after UI renders
+            async def setup_and_call_ai():
+                model_name = self.session.selected_model.get("name", "Assistant") if self.session.selected_model else "Assistant"
+                assistant_widget = AssistantMessageWidget(model_name, chat_display=chat_display)
+                await chat_display.add_message(assistant_widget)
+                assistant_widget.set_content("_Thinking... (Press Esc to stop)_")
+                await self.handle_message(user_input, assistant_widget)
+
+            self.call_after_refresh(setup_and_call_ai)

     async def handle_command(self, command_text: str) -> None:
         """Handle a slash command."""
@@ -449,6 +515,11 @@ class oAIChatApp(App):
             chat_display = self.query_one(ChatDisplay)
             error_widget = SystemMessageWidget(f"{result.message}")
             await chat_display.add_message(error_widget)
+        elif result.status == CommandStatus.WARNING:
+            # Display warning in chat
+            chat_display = self.query_one(ChatDisplay)
+            warning_widget = SystemMessageWidget(f"⚠️ {result.message}")
+            await chat_display.add_message(warning_widget)
         elif result.message:
             # Display success message
             chat_display = self.query_one(ChatDisplay)
@@ -486,61 +557,78 @@ class oAIChatApp(App):
         # Update online mode indicator
         input_bar.update_online_mode(self.session.online_enabled)

-    async def handle_message(self, user_input: str) -> None:
-        """Handle a chat message (user message already added by caller)."""
+    async def handle_message(self, user_input: str, assistant_widget: AssistantMessageWidget = None) -> None:
+        """Handle a chat message (user message and assistant widget already added by caller)."""
         chat_display = self.query_one(ChatDisplay)

-        # Create assistant message widget with loading indicator
-        model_name = self.session.selected_model.get("name", "Assistant") if self.session.selected_model else "Assistant"
-        assistant_widget = AssistantMessageWidget(model_name, chat_display=chat_display)
-        await chat_display.add_message(assistant_widget)
-
-        # Show loading indicator immediately
-        assistant_widget.set_content("_Thinking..._")
-
-        try:
-            # Stream response
-            response_iterator = self.session.send_message_async(
-                user_input,
-                stream=self.settings.stream_enabled,
-            )
-
-            # Stream and collect response
-            full_text, usage = await assistant_widget.stream_response(response_iterator)
-
-            # Add to history if we got a response
-            if full_text:
-                # Extract cost from usage or calculate from pricing
-                cost = 0.0
-                if usage and hasattr(usage, 'total_cost_usd') and usage.total_cost_usd:
-                    cost = usage.total_cost_usd
-                    self.notify(f"Cost from API: ${cost:.6f}", severity="information")
-                elif usage and self.session.selected_model:
-                    # Calculate cost from model pricing
-                    pricing = self.session.selected_model.get("pricing", {})
-                    prompt_cost = float(pricing.get("prompt", 0))
-                    completion_cost = float(pricing.get("completion", 0))
-                    # Prices are per token, convert to dollars
-                    prompt_total = usage.prompt_tokens * prompt_cost
-                    completion_total = usage.completion_tokens * completion_cost
-                    cost = prompt_total + completion_total
-                    if cost > 0:
-                        self.notify(f"Cost calculated: ${cost:.6f}", severity="information")
-
-                self.session.add_to_history(
-                    prompt=user_input,
-                    response=full_text,
-                    usage=usage,
-                    cost=cost,
-                )
-
-                # Update footer
-                self._update_footer()
-
-        except Exception as e:
-            assistant_widget.set_content(f"❌ Error: {str(e)}")
+        # If no assistant widget provided (legacy), create one
+        if assistant_widget is None:
+            model_name = self.session.selected_model.get("name", "Assistant") if self.session.selected_model else "Assistant"
+            assistant_widget = AssistantMessageWidget(model_name, chat_display=chat_display)
+            await chat_display.add_message(assistant_widget)
+            assistant_widget.set_content("_Thinking... (Press Esc to stop)_")
+
+        self._is_generating = True
+        self._cancel_generation = False
+
+        # Run streaming in background to keep UI responsive
+        async def stream_task():
+            try:
+                # Stream response
+                response_iterator = self.session.send_message_async(
+                    user_input,
+                    stream=self.settings.stream_enabled,
+                )
+
+                # Stream and collect response with cancellation support
+                full_text, usage = await assistant_widget.stream_response(
+                    response_iterator,
+                    cancel_check=lambda: self._cancel_generation
+                )
+
+                # Add to history if we got a response
+                if full_text:
+                    # Extract cost from usage or calculate from pricing
+                    cost = 0.0
+                    if usage and hasattr(usage, 'total_cost_usd') and usage.total_cost_usd:
+                        cost = usage.total_cost_usd
+                        self.notify(f"Cost from API: ${cost:.6f}", severity="information")
+                    elif usage and self.session.selected_model:
+                        # Calculate cost from model pricing
+                        pricing = self.session.selected_model.get("pricing", {})
+                        prompt_cost = float(pricing.get("prompt", 0))
+                        completion_cost = float(pricing.get("completion", 0))
+                        # Prices are per token, convert to dollars
+                        prompt_total = usage.prompt_tokens * prompt_cost
+                        completion_total = usage.completion_tokens * completion_cost
+                        cost = prompt_total + completion_total
+                        if cost > 0:
+                            self.notify(f"Cost calculated: ${cost:.6f}", severity="information")
+
+                    self.session.add_to_history(
+                        prompt=user_input,
+                        response=full_text,
+                        usage=usage,
+                        cost=cost,
+                    )
+
+                    # Update footer
+                    self._update_footer()
+
+                # Check if generation was cancelled
+                if self._cancel_generation and full_text:
+                    assistant_widget.set_content(full_text + "\n\n_[Generation stopped by user]_")
+
+            except Exception as e:
+                assistant_widget.set_content(f"❌ Error: {str(e)}")
+            finally:
+                self._is_generating = False
+                self._cancel_generation = False
+
+        # Create background task - don't await it!
+        asyncio.create_task(stream_task())

     def _update_footer(self) -> None:
         """Update footer statistics."""
@@ -860,9 +948,8 @@ class oAIChatApp(App):
elif "retry_prompt" in data: elif "retry_prompt" in data:
await self.handle_message(data["retry_prompt"]) await self.handle_message(data["retry_prompt"])
# Paste prompt # Paste command is disabled - users should use Cmd+V/Ctrl+V instead
elif "paste_prompt" in data: # No special handling needed
await self.handle_message(data["paste_prompt"])
def _show_model_selector(self, search: str = "", set_as_default: bool = False) -> None: def _show_model_selector(self, search: str = "", set_as_default: bool = False) -> None:
"""Show the model selector screen.""" """Show the model selector screen."""
@@ -994,7 +1081,7 @@ class oAIChatApp(App):
             callback=handle_confirmation
         )

-    def _navigate_history_backward(self, input_widget: Input) -> None:
+    def _navigate_history_backward(self, input_widget: TextArea) -> None:
         """Navigate backward through input history (Up arrow)."""
         if not self.input_history:
             return
@@ -1011,14 +1098,14 @@ class oAIChatApp(App):
         # Update input with history item
         if 0 <= self.history_index < len(self.input_history):
-            input_widget.value = self.input_history[self.history_index]
+            input_widget.text = self.input_history[self.history_index]
             # Move cursor to end
-            input_widget.cursor_position = len(input_widget.value)
+            input_widget.move_cursor((999, 999))  # Move to end

         # Clear flag after a short delay
         self.set_timer(0.1, lambda: setattr(self, "_navigating_history", False))

-    def _navigate_history_forward(self, input_widget: Input) -> None:
+    def _navigate_history_forward(self, input_widget: TextArea) -> None:
         """Navigate forward through input history (Down arrow)."""
         if not self.input_history or self.history_index == -1:
             return
@@ -1029,12 +1116,12 @@ class oAIChatApp(App):
         # Move forward in history
         if self.history_index < len(self.input_history) - 1:
             self.history_index += 1
-            input_widget.value = self.input_history[self.history_index]
-            input_widget.cursor_position = len(input_widget.value)
+            input_widget.text = self.input_history[self.history_index]
+            input_widget.move_cursor((999, 999))  # Move to end
         else:
             # At the newest item, clear the input
             self.history_index = -1
-            input_widget.value = ""
+            input_widget.clear()

         # Clear flag after a short delay
         self.set_timer(0.1, lambda: setattr(self, "_navigating_history", False))


@@ -111,6 +111,7 @@ class CommandsScreen(ModalScreen[None]):
 [green]/config model <id>[/] - Set default model
 [green]/config system <prompt>[/] - Set system prompt
 [green]/config maxtoken <num>[/] - Set token limit
+[green]/config log <level>[/] - Set log level (debug/info/warning/error/critical)

 [bold cyan]Memory & Context[/]


@@ -64,6 +64,9 @@ class HelpScreen(ModalScreen[None]):
 [bold]F1[/]          Show this help (Ctrl+H may not work)
 [bold]F2[/]          Open model selector (Ctrl+M may not work)
 [bold]F3[/]          Copy last AI response to clipboard
+[bold]Enter[/]       Submit message
+[bold]Ctrl+Enter[/]  Insert newline (for multiline messages)
+[bold]Esc[/]         Stop/cancel AI response generation
 [bold]Ctrl+S[/]      Show session statistics
 [bold]Ctrl+L[/]      Clear chat display
 [bold]Ctrl+P[/]      Show previous message
@@ -71,7 +74,7 @@ class HelpScreen(ModalScreen[None]):
 [bold]Ctrl+Y[/]      Copy last AI response (alternative to F3)
 [bold]Ctrl+Q[/]      Quit application
 [bold]Up/Down[/]     Navigate input history
-[bold]ESC[/]         Close dialogs
+[bold]ESC[/]         Clear input / Close dialogs

 [dim]Note: Some Ctrl keys may be captured by your terminal[/]

 [bold cyan]═══ SLASH COMMANDS ═══[/]
@@ -89,6 +92,7 @@ class HelpScreen(ModalScreen[None]):
 /config stream on      Enable streaming responses
 /system [prompt]       Set session system prompt
 /maxtoken [n]          Set session token limit
+/config log [level]    Set log level (debug/info/warning/error/critical)

 [bold yellow]Conversation Management:[/]
 /save [name]           Save current conversation


@@ -13,6 +13,7 @@ class ChatDisplay(ScrollableContainer):
     async def add_message(self, widget: Static) -> None:
         """Add a message widget to the display."""
         await self.mount(widget)
+        self.refresh(layout=True)
         self.scroll_end(animate=False)

     def clear_messages(self) -> None:


@@ -98,6 +98,7 @@ class CommandDropdown(VerticalScroll):
("/config model", "Set default model"), ("/config model", "Set default model"),
("/config system", "Set system prompt"), ("/config system", "Set system prompt"),
("/config maxtoken", "Set token limit"), ("/config maxtoken", "Set token limit"),
("/config log", "Set log level (debug/info/warning/error)"),
("/system", "Set system prompt"), ("/system", "Set system prompt"),
("/maxtoken", "Set token limit"), ("/maxtoken", "Set token limit"),
("/retry", "Retry last prompt"), ("/retry", "Retry last prompt"),


@@ -2,7 +2,53 @@
 from textual.app import ComposeResult
 from textual.containers import Horizontal
-from textual.widgets import Input, Static
+from textual.message import Message
+from textual.widgets import Static, TextArea
+
+
+class ChatTextArea(TextArea):
+    """Custom TextArea that sends submit message on Enter (unless dropdown is open)."""
+
+    class Submit(Message):
+        """Message sent when Enter is pressed."""
+
+        def __init__(self, value: str) -> None:
+            super().__init__()
+            self.value = value
+
+    def _on_key(self, event) -> None:
+        """Handle key events BEFORE TextArea processes them."""
+        # Check if command dropdown is visible
+        dropdown_visible = False
+        try:
+            from oai.tui.widgets.command_dropdown import CommandDropdown
+            dropdown = self.app.query_one(CommandDropdown)
+            dropdown_visible = dropdown.has_class("visible")
+        except:
+            pass
+
+        if event.key == "enter":
+            if dropdown_visible:
+                # Dropdown is open - prevent TextArea from inserting newline
+                # but let event bubble up to app for dropdown handling
+                event.prevent_default()
+                # Don't call stop() - let it bubble to app's on_key
+                # Don't call super - we don't want newline
+                return
+            else:
+                # Dropdown not visible - submit the message
+                event.prevent_default()
+                event.stop()
+                self.post_message(self.Submit(self.text))
+                return
+        elif event.key in ("ctrl+j", "ctrl+enter"):
+            # Insert newline on Ctrl+Enter
+            event.prevent_default()
+            event.stop()
+            self.insert("\n")
+            return
+
+        # For all other keys, let TextArea handle them normally
+        super()._on_key(event)


 class InputBar(Horizontal):
@@ -16,10 +62,9 @@ class InputBar(Horizontal):
     def compose(self) -> ComposeResult:
         """Compose the input bar."""
         yield Static(self._format_prefix(), id="input-prefix", classes="prefix-hidden" if not (self.mcp_status or self.online_mode) else "")
-        yield Input(
-            placeholder="Type a message or /command...",
-            id="chat-input"
-        )
+        text_area = ChatTextArea(id="chat-input")
+        text_area.show_line_numbers = False
+        yield text_area

     def _format_prefix(self) -> str:
         """Format the input prefix with status indicators."""
@@ -44,6 +89,6 @@ class InputBar(Horizontal):
         prefix = self.query_one("#input-prefix", Static)
         prefix.update(self._format_prefix())

-    def get_input(self) -> Input:
+    def get_input(self) -> ChatTextArea:
         """Get the input widget."""
-        return self.query_one("#chat-input", Input)
+        return self.query_one("#chat-input", ChatTextArea)


@@ -1,5 +1,6 @@
"""Message widgets for oAI TUI.""" """Message widgets for oAI TUI."""
import asyncio
from typing import Any, AsyncIterator, Tuple from typing import Any, AsyncIterator, Tuple
from rich.console import Console from rich.console import Console
@@ -35,7 +36,13 @@ class UserMessageWidget(Static):
     def compose(self) -> ComposeResult:
         """Compose the user message."""
-        yield Static(f"[bold green]You:[/] {self.content}")
+        yield Static(f"[bold green]You:[/] {self.content}", id="user-message-content")
+
+    def update_content(self, new_content: str) -> None:
+        """Update the message content."""
+        self.content = new_content
+        content_widget = self.query_one("#user-message-content", Static)
+        content_widget.update(f"[bold green]You:[/] {new_content}")


 class SystemMessageWidget(Static):
@@ -64,13 +71,22 @@ class AssistantMessageWidget(Static):
yield Static(f"[bold]{self.model_name}:[/]", id="assistant-label") yield Static(f"[bold]{self.model_name}:[/]", id="assistant-label")
yield RichLog(id="assistant-content", highlight=True, markup=True, wrap=True) yield RichLog(id="assistant-content", highlight=True, markup=True, wrap=True)
async def stream_response(self, response_iterator: AsyncIterator) -> Tuple[str, Any]: async def stream_response(self, response_iterator: AsyncIterator, cancel_check=None) -> Tuple[str, Any]:
"""Stream tokens progressively and return final text and usage.""" """Stream tokens progressively and return final text and usage.
Args:
response_iterator: Async iterator of response chunks
cancel_check: Optional callable that returns True if generation should be cancelled
"""
log = self.query_one("#assistant-content", RichLog) log = self.query_one("#assistant-content", RichLog)
self.full_text = "" self.full_text = ""
usage = None usage = None
async for chunk in response_iterator: async for chunk in response_iterator:
# Check for cancellation
if cancel_check and cancel_check():
break
if hasattr(chunk, "delta_content") and chunk.delta_content: if hasattr(chunk, "delta_content") and chunk.delta_content:
self.full_text += chunk.delta_content self.full_text += chunk.delta_content
log.clear() log.clear()
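The `cancel_check` hook lets the app stop consuming the stream from the Esc handler without tearing anything down mid-chunk, and the periodic `await asyncio.sleep(0)` keeps the event loop responsive while chunks arrive. A minimal sketch of that pattern, with a fake stream standing in for the provider iterator:

```python
import asyncio
from typing import AsyncIterator, Callable, Optional


async def fake_stream() -> AsyncIterator[str]:
    """Stand-in for a provider's chunk iterator."""
    for token in ["Hel", "lo ", "wor", "ld"]:
        await asyncio.sleep(0.1)
        yield token


async def collect(
    stream: AsyncIterator[str],
    cancel_check: Optional[Callable[[], bool]] = None,
) -> str:
    text = ""
    async for chunk in stream:
        if cancel_check and cancel_check():
            break  # stop consuming; whatever arrived so far is kept
        text += chunk
        await asyncio.sleep(0)  # yield control so the UI stays responsive
    return text


print(asyncio.run(collect(fake_stream())))  # "Hello world"
```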
@@ -86,6 +102,9 @@ class AssistantMessageWidget(Static):
             if hasattr(chunk, "usage") and chunk.usage:
                 usage = chunk.usage

+            # Yield control to event loop so UI stays responsive
+            await asyncio.sleep(0)

         return self.full_text, usage

     def set_content(self, content: str) -> None:


@@ -278,11 +278,16 @@ def prepare_file_attachment(
             file_data = f.read()

         if category == "image":
-            # Check if model supports images
-            input_modalities = model_capabilities.get("architecture", {}).get("input_modalities", [])
-            if "image" not in input_modalities:
-                logger.warning(f"Model does not support images")
-                return None
+            # Check if model supports images - try multiple possible locations
+            input_modalities = (
+                model_capabilities.get("input_modalities", []) or
+                model_capabilities.get("architecture", {}).get("input_modalities", [])
+            )
+            # If no input_modalities found or image not in list, try to attach anyway
+            # Some models support images but don't advertise it properly
+            if input_modalities and "image" not in input_modalities:
+                logger.warning(f"Model may not support images, attempting anyway...")

             b64_data = base64.b64encode(file_data).decode("utf-8")
             return {
@@ -291,12 +296,16 @@ def prepare_file_attachment(
             }
         elif category == "pdf":
-            # Check if model supports PDFs
-            input_modalities = model_capabilities.get("architecture", {}).get("input_modalities", [])
+            # Check if model supports PDFs - try multiple possible locations
+            input_modalities = (
+                model_capabilities.get("input_modalities", []) or
+                model_capabilities.get("architecture", {}).get("input_modalities", [])
+            )
             supports_pdf = any(mod in input_modalities for mod in ["document", "pdf", "file"])
-            if not supports_pdf:
-                logger.warning(f"Model does not support PDFs")
-                return None
+
+            if input_modalities and not supports_pdf:
+                logger.warning(f"Model may not support PDFs, attempting anyway...")

             b64_data = base64.b64encode(file_data).decode("utf-8")
             return {
@@ -321,3 +330,49 @@ def prepare_file_attachment(
     except Exception as e:
         logger.error(f"Error preparing file attachment {path}: {e}")
         return None
+
+
+def parse_file_attachments(user_input: str) -> Tuple[str, list[Path]]:
+    """
+    Parse user input for @<file> or @file syntax and extract file paths.
+
+    Args:
+        user_input: User's message that may contain @<file> or @file references
+
+    Returns:
+        Tuple of (cleaned_text, list_of_file_paths)
+
+    Example:
+        >>> parse_file_attachments("Look at @<image.png> and @/path/to/doc.pdf")
+        ("Look at and ", [Path("image.png"), Path("/path/to/doc.pdf")])
+    """
+    import re
+
+    logger = get_logger()
+
+    # Pattern to match both @<filepath> and @filepath
+    # Matches @<...> or @ followed by path-like strings
+    pattern = r'@<([^>]+)>|@([/~.][\S]+|[a-zA-Z]:[/\\][\S]+)'
+
+    # Find all file references
+    file_paths = []
+    matches_to_remove = []
+
+    for match in re.finditer(pattern, user_input):
+        # Group 1 is for @<filepath>, Group 2 is for @filepath
+        file_path_str = (match.group(1) or match.group(2)).strip()
+        file_path = Path(file_path_str).expanduser().resolve()
+
+        if file_path.exists():
+            file_paths.append(file_path)
+            logger.info(f"Found file attachment: {file_path}")
+            matches_to_remove.append(match.group(0))
+        else:
+            logger.warning(f"File not found: {file_path_str}")
+            matches_to_remove.append(match.group(0))
+
+    # Remove @<file> and @file references from the text
+    cleaned_text = user_input
+    for match_str in matches_to_remove:
+        cleaned_text = cleaned_text.replace(match_str, '', 1)
+
+    return cleaned_text, file_paths
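To see what the @-reference pattern matches without touching the filesystem, here is a trimmed-down variant that only extracts the raw path strings. The helper name is illustrative; the regex is the same one used above.

```python
import re
from typing import List, Tuple

PATTERN = r'@<([^>]+)>|@([/~.][\S]+|[a-zA-Z]:[/\\][\S]+)'


def extract_references(text: str) -> Tuple[str, List[str]]:
    """Return (text with @refs removed, raw path strings), no existence checks."""
    paths = []
    for match in re.finditer(PATTERN, text):
        # Group 1 captures @<filepath>, group 2 captures bare @filepath
        paths.append((match.group(1) or match.group(2)).strip())
    cleaned = re.sub(PATTERN, "", text)
    return cleaned, paths


print(extract_references("Review @./src/main.py and @<notes.txt>"))
# ('Review  and ', ['./src/main.py', 'notes.txt'])
```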


@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "oai"
-version = "3.0.0-b3"  # MUST match oai/__init__.py __version__
+version = "3.0.0-b4"  # MUST match oai/__init__.py __version__
 description = "Open AI Chat Client - Multi-provider terminal chat with MCP support"
 readme = "README.md"
 license = {text = "MIT"}
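The comment notes that this version string must stay in sync with `oai/__init__.py`. A small, hedged sketch of a test that would catch drift; it assumes it runs from the repository root with the package importable, and uses `tomllib` from Python 3.11+ (the `tomli` backport works on older interpreters):

```python
# test_version_sync.py - illustrative check that pyproject.toml and the package agree
import tomllib  # Python 3.11+; use the tomli backport on older interpreters
from pathlib import Path

import oai


def test_version_matches_pyproject() -> None:
    data = tomllib.loads(Path("pyproject.toml").read_text(encoding="utf-8"))
    assert data["project"]["version"] == oai.__version__
```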