Files
oai/oai/tui/screens/model_selector.py
2026-02-04 11:22:53 +01:00

255 lines
8.9 KiB
Python

"""Model selector screen for oAI TUI."""
from typing import List, Optional

from textual.app import ComposeResult
from textual.containers import Container, Horizontal, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, DataTable, Input, Label, Static
class ModelSelectorScreen(ModalScreen[Optional[dict]]):
    """Modal screen for selecting an AI model.

    Shows the given model dicts in a searchable table and is dismissed with
    the chosen model dict, or ``None`` when the user cancels.

    Fields read from each model dict: ``id``, ``name``, ``context_length``,
    ``pricing`` (``prompt`` / ``completion``), ``architecture.modality`` and
    ``supported_parameters``. Missing or ``None`` fields are rendered as
    blanks / "N/A" / "not supported" instead of hiding the model.
    """

    DEFAULT_CSS = """
    ModelSelectorScreen {
        align: center middle;
    }
    ModelSelectorScreen > Container {
        width: 90%;
        height: 85%;
        background: #1e1e1e;
        border: solid #555555;
        layout: vertical;
    }
    ModelSelectorScreen .header {
        height: 3;
        width: 100%;
        background: #2d2d2d;
        color: #cccccc;
        padding: 0 2;
        content-align: center middle;
    }
    ModelSelectorScreen .search-input {
        height: 3;
        width: 100%;
        background: #2a2a2a;
        border: solid #555555;
        margin: 0 0 1 0;
    }
    ModelSelectorScreen .search-input:focus {
        border: solid #888888;
    }
    ModelSelectorScreen DataTable {
        height: 1fr;
        width: 100%;
        background: #1e1e1e;
        border: solid #555555;
    }
    ModelSelectorScreen .footer {
        height: 5;
        width: 100%;
        background: #2d2d2d;
        padding: 1 2;
        align: center middle;
    }
    ModelSelectorScreen Button {
        margin: 0 1;
    }
    """

    # (header label, width) for each table column, in display order.
    _COLUMNS = (
        ("#", 5),
        ("Model ID", 35),
        ("Name", 30),
        ("Context", 10),
        ("Price", 12),
        ("Img", 4),
        ("Tools", 6),
        ("Online", 7),
    )

    # Lists up to this size are assumed to fit on screen, so the table gets
    # initial focus; longer lists start in the search box instead.
    _SMALL_LIST_THRESHOLD = 20

    def __init__(self, models: List[dict], current_model: Optional[str] = None):
        """Create the screen.

        Args:
            models: Model dicts to choose from.
            current_model: ID of the active model, if any. It is only stored;
                nothing in this class reads it.
        """
        super().__init__()
        self.all_models = models
        # Models matching the current search. Table row keys index into this.
        self.filtered_models = models
        self.current_model = current_model
        # Last highlighted or clicked model; the Select button returns it.
        self.selected_model: Optional[dict] = None

    def compose(self) -> ComposeResult:
        """Build the header, search box, model table and button footer."""
        with Container():
            yield Static(
                f"[bold]Select Model[/] [dim]({len(self.all_models)} available)[/]",
                classes="header",
            )
            yield Input(
                placeholder="Search to filter models...",
                id="search-input",
                classes="search-input",
            )
            yield DataTable(
                id="model-table",
                cursor_type="row",
                show_header=True,
                zebra_stripes=True,
            )
            # Buttons sit side by side. The 5-row footer (with 1 row of padding
            # top and bottom) has room for only one line of 3-row Textual
            # buttons, so stacking them vertically clipped Cancel.
            with Horizontal(classes="footer"):
                yield Button("Select", id="select", variant="success")
                yield Button("Cancel", id="cancel", variant="error")

    def on_mount(self) -> None:
        """Add the table columns, fill the rows and choose the initial focus."""
        table = self.query_one("#model-table", DataTable)
        for label, width in self._COLUMNS:
            table.add_column(label, width=width)

        self._populate_table()

        if len(self.filtered_models) <= self._SMALL_LIST_THRESHOLD:
            table.focus()
        else:
            self.query_one("#search-input", Input).focus()

    @staticmethod
    def _format_price(pricing: dict) -> str:
        """Format per-token prompt/completion prices as "$X.XX/$Y.YY" per 1M tokens.

        Returns "Free" when both prices are zero and "N/A" when either value
        cannot be parsed as a number. Missing entries count as "0".
        """
        try:
            prompt = float(pricing.get("prompt", "0")) * 1_000_000
            completion = float(pricing.get("completion", "0")) * 1_000_000
        except (TypeError, ValueError, OverflowError):
            return "N/A"
        if prompt == 0 and completion == 0:
            return "Free"
        return f"${prompt:.2f}/${completion:.2f}"

    @classmethod
    def _format_row(cls, model: dict) -> tuple:
        """Return the display cells for one model, excluding the "#" column.

        Raises AttributeError/TypeError when ``model`` (or a nested field) has
        an unexpected type; the caller skips such entries.
        """
        # `or` defaults also cover keys that are present but set to None.
        model_id = model.get("id") or ""
        name = model.get("name") or ""
        context_length = model.get("context_length")
        context = "N/A" if context_length is None else str(context_length)

        modality = (model.get("architecture") or {}).get("modality") or ""
        supported_params = model.get("supported_parameters") or []

        supports_vision = "image" in modality
        supports_tools = "tools" in supported_params or "tool_choice" in supported_params
        # Every model except the free router is assumed to accept the ":online"
        # suffix; IDs already containing ":online" are covered by this too.
        supports_online = model_id != "openrouter/free"

        return (
            model_id,
            name,
            context,
            cls._format_price(model.get("pricing") or {}),
            "" if supports_vision else "-",
            "" if supports_tools else "-",
            "" if supports_online else "-",
        )

    def _populate_table(self) -> None:
        """Rebuild the table rows from ``self.filtered_models``.

        Each row's key and "#" cell hold its 1-based position in
        ``filtered_models``. The selection handlers use that to map a row back
        to its model, so skipped entries do not shift the mapping.
        """
        table = self.query_one("#model-table", DataTable)
        table.clear()
        for idx, model in enumerate(self.filtered_models, 1):
            try:
                cells = self._format_row(model)
            except Exception:
                # Best effort: skip a malformed entry rather than failing to
                # show the whole list.
                continue
            table.add_row(str(idx), *cells, key=str(idx))

    def _model_at_cursor(self) -> Optional[dict]:
        """Return the model under the table cursor, or ``None`` if there is none."""
        table = self.query_one("#model-table", DataTable)
        row = table.cursor_row
        # The cursor can point at row 0 of an empty table, and get_row_at()
        # raises for non-existent rows, so bounds-check first.
        if row is None or not 0 <= row < table.row_count:
            return None
        try:
            index = int(table.get_row_at(row)[0]) - 1
        except (ValueError, IndexError, TypeError):
            return None
        if 0 <= index < len(self.filtered_models):
            return self.filtered_models[index]
        return None

    def on_input_changed(self, event: Input.Changed) -> None:
        """Filter models by case-insensitive substring match on id or name."""
        if event.input.id != "search-input":
            return

        search_term = event.value.lower()
        if search_term:
            self.filtered_models = [
                m for m in self.all_models
                if search_term in (m.get("id") or "").lower()
                or search_term in (m.get("name") or "").lower()
            ]
        else:
            self.filtered_models = self.all_models

        # The previous selection may now be filtered out. Drop it so Select
        # cannot return a hidden model; highlighting a row sets it again.
        self.selected_model = None
        self._populate_table()

    def on_data_table_row_selected(self, event: DataTable.RowSelected) -> None:
        """Remember the row chosen by mouse click or the table's own Enter."""
        try:
            row_index = int(event.row_key.value) - 1
        except (TypeError, ValueError):
            return
        if 0 <= row_index < len(self.filtered_models):
            self.selected_model = self.filtered_models[row_index]

    def on_data_table_row_highlighted(self, event: DataTable.RowHighlighted) -> None:
        """Track the highlighted row (arrow-key navigation) as the selection."""
        model = self._model_at_cursor()
        if model is not None:
            self.selected_model = model

    def on_button_pressed(self, event: Button.Pressed) -> None:
        """Dismiss with the selected model on Select, or with ``None`` on Cancel."""
        if event.button.id == "select":
            # Fall back to the row under the cursor in case no highlight has
            # been recorded since the last search.
            self.dismiss(self.selected_model or self._model_at_cursor())
        else:
            self.dismiss(None)

    def on_key(self, event) -> None:
        """Handle keyboard shortcuts (``event`` is a Textual key event).

        Escape cancels. Enter in the search box moves focus to the table.
        Enter in the table picks the highlighted model.
        """
        if event.key == "escape":
            self.dismiss(None)
            return
        if event.key != "enter":
            return

        search_input = self.query_one("#search-input", Input)
        table = self.query_one("#model-table", DataTable)
        if search_input.has_focus:
            table.focus()
        elif table.has_focus:
            # Restricted to the table: Enter on a focused button is handled by
            # on_button_pressed, so Cancel must not pick the highlighted model.
            model = self._model_at_cursor() or self.selected_model
            if model:
                self.dismiss(model)