oAI version 3.0 beta 1
This commit is contained in:
254
oai/tui/screens/model_selector.py
Normal file
254
oai/tui/screens/model_selector.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""Model selector screen for oAI TUI."""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Container, Vertical
|
||||
from textual.screen import ModalScreen
|
||||
from textual.widgets import Button, DataTable, Input, Label, Static
|
||||
|
||||
|
||||
class ModelSelectorScreen(ModalScreen[Optional[dict]]):
|
||||
"""Modal screen for selecting an AI model."""
|
||||
|
||||
DEFAULT_CSS = """
|
||||
ModelSelectorScreen {
|
||||
align: center middle;
|
||||
}
|
||||
|
||||
ModelSelectorScreen > Container {
|
||||
width: 90%;
|
||||
height: 85%;
|
||||
background: #1e1e1e;
|
||||
border: solid #555555;
|
||||
layout: vertical;
|
||||
}
|
||||
|
||||
ModelSelectorScreen .header {
|
||||
height: 3;
|
||||
width: 100%;
|
||||
background: #2d2d2d;
|
||||
color: #cccccc;
|
||||
padding: 0 2;
|
||||
content-align: center middle;
|
||||
}
|
||||
|
||||
ModelSelectorScreen .search-input {
|
||||
height: 3;
|
||||
width: 100%;
|
||||
background: #2a2a2a;
|
||||
border: solid #555555;
|
||||
margin: 0 0 1 0;
|
||||
}
|
||||
|
||||
ModelSelectorScreen .search-input:focus {
|
||||
border: solid #888888;
|
||||
}
|
||||
|
||||
ModelSelectorScreen DataTable {
|
||||
height: 1fr;
|
||||
width: 100%;
|
||||
background: #1e1e1e;
|
||||
border: solid #555555;
|
||||
}
|
||||
|
||||
ModelSelectorScreen .footer {
|
||||
height: 5;
|
||||
width: 100%;
|
||||
background: #2d2d2d;
|
||||
padding: 1 2;
|
||||
align: center middle;
|
||||
}
|
||||
|
||||
ModelSelectorScreen Button {
|
||||
margin: 0 1;
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, models: List[dict], current_model: Optional[str] = None):
|
||||
super().__init__()
|
||||
self.all_models = models
|
||||
self.filtered_models = models
|
||||
self.current_model = current_model
|
||||
self.selected_model: Optional[dict] = None
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Compose the screen."""
|
||||
with Container():
|
||||
yield Static(
|
||||
f"[bold]Select Model[/] [dim]({len(self.all_models)} available)[/]",
|
||||
classes="header"
|
||||
)
|
||||
yield Input(placeholder="Search to filter models...", id="search-input", classes="search-input")
|
||||
yield DataTable(id="model-table", cursor_type="row", show_header=True, zebra_stripes=True)
|
||||
with Vertical(classes="footer"):
|
||||
yield Button("Select", id="select", variant="success")
|
||||
yield Button("Cancel", id="cancel", variant="error")
|
||||
|
||||
def on_mount(self) -> None:
|
||||
"""Initialize the table when mounted."""
|
||||
table = self.query_one("#model-table", DataTable)
|
||||
|
||||
# Add columns
|
||||
table.add_column("#", width=5)
|
||||
table.add_column("Model ID", width=35)
|
||||
table.add_column("Name", width=30)
|
||||
table.add_column("Context", width=10)
|
||||
table.add_column("Price", width=12)
|
||||
table.add_column("Img", width=4)
|
||||
table.add_column("Tools", width=6)
|
||||
table.add_column("Online", width=7)
|
||||
|
||||
# Populate table
|
||||
self._populate_table()
|
||||
|
||||
# Focus table if list is small (fits on screen), otherwise focus search
|
||||
if len(self.filtered_models) <= 20:
|
||||
table.focus()
|
||||
else:
|
||||
search_input = self.query_one("#search-input", Input)
|
||||
search_input.focus()
|
||||
|
||||
def _populate_table(self) -> None:
|
||||
"""Populate the table with models."""
|
||||
table = self.query_one("#model-table", DataTable)
|
||||
table.clear()
|
||||
|
||||
rows_added = 0
|
||||
for idx, model in enumerate(self.filtered_models, 1):
|
||||
try:
|
||||
model_id = model.get("id", "")
|
||||
name = model.get("name", "")
|
||||
context = str(model.get("context_length", "N/A"))
|
||||
|
||||
# Format pricing
|
||||
pricing = model.get("pricing", {})
|
||||
prompt_price = pricing.get("prompt", "0")
|
||||
completion_price = pricing.get("completion", "0")
|
||||
|
||||
# Convert to numbers and format
|
||||
try:
|
||||
prompt = float(prompt_price) * 1000000 # Convert to per 1M tokens
|
||||
completion = float(completion_price) * 1000000
|
||||
if prompt == 0 and completion == 0:
|
||||
price = "Free"
|
||||
else:
|
||||
price = f"${prompt:.2f}/${completion:.2f}"
|
||||
except:
|
||||
price = "N/A"
|
||||
|
||||
# Check capabilities
|
||||
architecture = model.get("architecture", {})
|
||||
modality = architecture.get("modality", "")
|
||||
supported_params = model.get("supported_parameters", [])
|
||||
|
||||
# Vision support: check if modality contains "image"
|
||||
supports_vision = "image" in modality
|
||||
|
||||
# Tool support: check if "tools" or "tool_choice" in supported_parameters
|
||||
supports_tools = "tools" in supported_params or "tool_choice" in supported_params
|
||||
|
||||
# Online support: check if model can use :online suffix (most models can)
|
||||
# Models that already have :online in their ID support it
|
||||
supports_online = ":online" in model_id or model_id not in ["openrouter/free"]
|
||||
|
||||
# Format capability indicators
|
||||
img_indicator = "✓" if supports_vision else "-"
|
||||
tools_indicator = "✓" if supports_tools else "-"
|
||||
web_indicator = "✓" if supports_online else "-"
|
||||
|
||||
# Add row
|
||||
table.add_row(
|
||||
str(idx),
|
||||
model_id,
|
||||
name,
|
||||
context,
|
||||
price,
|
||||
img_indicator,
|
||||
tools_indicator,
|
||||
web_indicator,
|
||||
key=str(idx)
|
||||
)
|
||||
rows_added += 1
|
||||
|
||||
except Exception:
|
||||
# Silently skip rows that fail to add
|
||||
pass
|
||||
|
||||
def on_input_changed(self, event: Input.Changed) -> None:
|
||||
"""Filter models based on search input."""
|
||||
if event.input.id != "search-input":
|
||||
return
|
||||
|
||||
search_term = event.value.lower()
|
||||
|
||||
if not search_term:
|
||||
self.filtered_models = self.all_models
|
||||
else:
|
||||
self.filtered_models = [
|
||||
m for m in self.all_models
|
||||
if search_term in m.get("id", "").lower()
|
||||
or search_term in m.get("name", "").lower()
|
||||
]
|
||||
|
||||
self._populate_table()
|
||||
|
||||
def on_data_table_row_selected(self, event: DataTable.RowSelected) -> None:
|
||||
"""Handle row selection (click or arrow navigation)."""
|
||||
try:
|
||||
row_index = int(event.row_key.value) - 1
|
||||
if 0 <= row_index < len(self.filtered_models):
|
||||
self.selected_model = self.filtered_models[row_index]
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
def on_data_table_row_highlighted(self, event) -> None:
|
||||
"""Handle row highlight (arrow key navigation)."""
|
||||
try:
|
||||
table = self.query_one("#model-table", DataTable)
|
||||
if table.cursor_row is not None:
|
||||
row_data = table.get_row_at(table.cursor_row)
|
||||
if row_data:
|
||||
row_index = int(row_data[0]) - 1
|
||||
if 0 <= row_index < len(self.filtered_models):
|
||||
self.selected_model = self.filtered_models[row_index]
|
||||
except (ValueError, IndexError, AttributeError):
|
||||
pass
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
"""Handle button press."""
|
||||
if event.button.id == "select":
|
||||
if self.selected_model:
|
||||
self.dismiss(self.selected_model)
|
||||
else:
|
||||
# No selection, dismiss without result
|
||||
self.dismiss(None)
|
||||
else:
|
||||
self.dismiss(None)
|
||||
|
||||
def on_key(self, event) -> None:
|
||||
"""Handle keyboard shortcuts."""
|
||||
if event.key == "escape":
|
||||
self.dismiss(None)
|
||||
elif event.key == "enter":
|
||||
# If in search input, move to table
|
||||
search_input = self.query_one("#search-input", Input)
|
||||
if search_input.has_focus:
|
||||
table = self.query_one("#model-table", DataTable)
|
||||
table.focus()
|
||||
# If in table or anywhere else, select current row
|
||||
else:
|
||||
table = self.query_one("#model-table", DataTable)
|
||||
# Get the currently highlighted row
|
||||
if table.cursor_row is not None:
|
||||
try:
|
||||
row_key = table.get_row_at(table.cursor_row)
|
||||
if row_key:
|
||||
row_index = int(row_key[0]) - 1
|
||||
if 0 <= row_index < len(self.filtered_models):
|
||||
selected = self.filtered_models[row_index]
|
||||
self.dismiss(selected)
|
||||
except (ValueError, IndexError, AttributeError):
|
||||
# Fall back to previously selected model
|
||||
if self.selected_model:
|
||||
self.dismiss(self.selected_model)
|
||||
Reference in New Issue
Block a user