255 lines
8.9 KiB
Python
255 lines
8.9 KiB
Python
"""Model selector screen for oAI TUI."""
|
|
|
|
from typing import List, Optional
|
|
|
|
from textual.app import ComposeResult
|
|
from textual.containers import Container, Vertical
|
|
from textual.screen import ModalScreen
|
|
from textual.widgets import Button, DataTable, Input, Label, Static
|
|
|
|
|
|
class ModelSelectorScreen(ModalScreen[Optional[dict]]):
    """Modal screen for selecting an AI model.

    The screen shows ``models`` in a filterable table. It dismisses with the
    chosen model dict, or with ``None`` when the user cancels (Escape, the
    Cancel button, or Select with nothing chosen).
    """

    # Lists at or below this many models start with the table focused;
    # longer lists start with the search box focused so the user can filter.
    _TABLE_FOCUS_THRESHOLD = 20

    # Model IDs that are shown without the "online" capability marker.
    _NO_ONLINE_MODEL_IDS = frozenset({"openrouter/free"})

    # The footer uses a horizontal layout: its content area is only 3 rows
    # tall (height 5 minus 1 row of padding top and bottom), so two stacked
    # buttons would clip the Cancel button.
    DEFAULT_CSS = """
    ModelSelectorScreen {
        align: center middle;
    }

    ModelSelectorScreen > Container {
        width: 90%;
        height: 85%;
        background: #1e1e1e;
        border: solid #555555;
        layout: vertical;
    }

    ModelSelectorScreen .header {
        height: 3;
        width: 100%;
        background: #2d2d2d;
        color: #cccccc;
        padding: 0 2;
        content-align: center middle;
    }

    ModelSelectorScreen .search-input {
        height: 3;
        width: 100%;
        background: #2a2a2a;
        border: solid #555555;
        margin: 0 0 1 0;
    }

    ModelSelectorScreen .search-input:focus {
        border: solid #888888;
    }

    ModelSelectorScreen DataTable {
        height: 1fr;
        width: 100%;
        background: #1e1e1e;
        border: solid #555555;
    }

    ModelSelectorScreen .footer {
        height: 5;
        width: 100%;
        background: #2d2d2d;
        padding: 1 2;
        layout: horizontal;
        align: center middle;
    }

    ModelSelectorScreen Button {
        margin: 0 1;
    }
    """

    def __init__(self, models: List[dict], current_model: Optional[str] = None):
        super().__init__()
        # Full, unfiltered model list as supplied by the caller.
        self.all_models = models
        # Subset currently shown in the table. Each row's key and "#" cell
        # hold the model's 1-based position in this list.
        self.filtered_models = models
        # NOTE(review): stored but not read anywhere in this screen.
        self.current_model = current_model
        # Model last highlighted/selected in the table; returned by Select.
        self.selected_model: Optional[dict] = None

    def compose(self) -> ComposeResult:
        """Build the header, search box, model table and button footer."""
        with Container():
            yield Static(
                f"[bold]Select Model[/] [dim]({len(self.all_models)} available)[/]",
                classes="header"
            )
            yield Input(placeholder="Search to filter models...", id="search-input", classes="search-input")
            yield DataTable(id="model-table", cursor_type="row", show_header=True, zebra_stripes=True)
            with Vertical(classes="footer"):
                yield Button("Select", id="select", variant="success")
                yield Button("Cancel", id="cancel", variant="error")

    def on_mount(self) -> None:
        """Create the table columns, fill the rows and set the initial focus."""
        table = self.query_one("#model-table", DataTable)

        table.add_column("#", width=5)
        table.add_column("Model ID", width=35)
        table.add_column("Name", width=30)
        table.add_column("Context", width=10)
        table.add_column("Price", width=12)
        table.add_column("Img", width=4)
        table.add_column("Tools", width=6)
        table.add_column("Online", width=7)

        self._populate_table()

        # Short lists can be browsed directly; long ones are easier to filter first.
        if len(self.filtered_models) <= self._TABLE_FOCUS_THRESHOLD:
            table.focus()
        else:
            self.query_one("#search-input", Input).focus()

    def _populate_table(self) -> None:
        """Rebuild the table rows from ``self.filtered_models``.

        Rows that cannot be formatted are skipped and logged, so one
        malformed entry does not hide the rest of the list.
        """
        table = self.query_one("#model-table", DataTable)
        table.clear()

        for position, model in enumerate(self.filtered_models, 1):
            try:
                table.add_row(*self._format_row(position, model), key=str(position))
            except Exception as exc:  # deliberate best-effort: skip the bad row
                self.log.warning(f"Skipping model row {position}: {exc!r}")

    @classmethod
    def _format_row(cls, position: int, model: dict) -> tuple:
        """Return the display cells for one model, in table-column order."""
        # `or` fallbacks treat keys that are present but None like missing keys.
        model_id = model.get("id") or ""
        name = model.get("name") or ""
        context_length = model.get("context_length")
        context = "N/A" if context_length is None else str(context_length)

        architecture = model.get("architecture") or {}
        modality = architecture.get("modality") or ""
        supported_params = model.get("supported_parameters") or []

        # Vision: the modality string mentions "image".
        supports_vision = "image" in modality
        # Tools: the model accepts a "tools" or "tool_choice" parameter.
        supports_tools = "tools" in supported_params or "tool_choice" in supported_params
        # Online: IDs already carrying ":online", or any ID not in the deny set.
        supports_online = ":online" in model_id or model_id not in cls._NO_ONLINE_MODEL_IDS

        return (
            str(position),
            model_id,
            name,
            context,
            cls._format_price(model.get("pricing") or {}),
            "✓" if supports_vision else "-",
            "✓" if supports_tools else "-",
            "✓" if supports_online else "-",
        )

    @staticmethod
    def _format_price(pricing: dict) -> str:
        """Format per-token prompt/completion prices as "$in/$out" per 1M tokens.

        Returns "Free" when both prices are zero and "N/A" when either price
        is not a number.
        """
        try:
            prompt = float(pricing.get("prompt", "0")) * 1_000_000
            completion = float(pricing.get("completion", "0")) * 1_000_000
        except (TypeError, ValueError):
            return "N/A"
        if prompt == 0 and completion == 0:
            return "Free"
        return f"${prompt:.2f}/${completion:.2f}"

    def _model_for_position(self, position: object) -> Optional[dict]:
        """Map a 1-based row position (row key or "#" cell) to its model, or None."""
        try:
            row_index = int(position) - 1
        except (TypeError, ValueError):
            return None
        if 0 <= row_index < len(self.filtered_models):
            return self.filtered_models[row_index]
        return None

    def _model_at_cursor(self) -> Optional[dict]:
        """Return the model under the table cursor, or None if there is no valid row."""
        table = self.query_one("#model-table", DataTable)
        cursor_row = table.cursor_row
        # An empty table still reports cursor row 0, and get_row_at would
        # raise for it, so bound-check against the actual row count first.
        if cursor_row is None or not 0 <= cursor_row < table.row_count:
            return None
        row = table.get_row_at(cursor_row)
        return self._model_for_position(row[0]) if row else None

    def on_input_changed(self, event: Input.Changed) -> None:
        """Filter models by case-insensitive substring match on ID or name."""
        if event.input.id != "search-input":
            return

        search_term = event.value.lower()
        if search_term:
            self.filtered_models = [
                m for m in self.all_models
                if search_term in str(m.get("id") or "").lower()
                or search_term in str(m.get("name") or "").lower()
            ]
        else:
            self.filtered_models = self.all_models

        self._populate_table()
        # Rows were renumbered, so an earlier selection may be stale or hidden.
        # Follow the row the cursor now shows (None when nothing matched).
        self.selected_model = self._model_at_cursor()

    def on_data_table_row_selected(self, event: DataTable.RowSelected) -> None:
        """Remember the model of a row chosen by click or Enter."""
        model = self._model_for_position(event.row_key.value)
        if model is not None:
            self.selected_model = model

    def on_data_table_row_highlighted(self, event: DataTable.RowHighlighted) -> None:
        """Track the model under the cursor as it moves."""
        model = self._model_at_cursor()
        if model is not None:
            self.selected_model = model

    def on_button_pressed(self, event: Button.Pressed) -> None:
        """Dismiss with the selected model (Select) or ``None`` (Cancel / no selection)."""
        if event.button.id == "select" and self.selected_model:
            self.dismiss(self.selected_model)
        else:
            self.dismiss(None)

    def on_key(self, event) -> None:
        """Escape cancels; Enter moves from search to the table, or picks the current row."""
        if event.key == "escape":
            event.stop()
            self.dismiss(None)
            return

        if event.key != "enter":
            return

        # A focused button handles Enter itself, so Cancel really cancels.
        if isinstance(self.focused, Button):
            return

        search_input = self.query_one("#search-input", Input)
        if search_input.has_focus:
            event.stop()
            self.query_one("#model-table", DataTable).focus()
            return

        model = self._model_at_cursor() or self.selected_model
        if model:
            # Stop the key here so it does not go on to trigger another
            # binding once this screen has been dismissed.
            event.stop()
            self.dismiss(model)
|