"""Model selector screen for oAI TUI."""

from typing import List, Optional

from textual import events
from textual.app import ComposeResult
from textual.containers import Container, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, DataTable, Input, Label, Static


class ModelSelectorScreen(ModalScreen[Optional[dict]]):
    """Modal screen for selecting an AI model.

    Dismisses with the chosen model dict, or ``None`` when cancelled or when
    nothing is selected.
    """

    DEFAULT_CSS = """
    ModelSelectorScreen {
        align: center middle;
    }

    ModelSelectorScreen > Container {
        width: 90%;
        height: 85%;
        background: #1e1e1e;
        border: solid #555555;
        layout: vertical;
    }

    ModelSelectorScreen .header {
        height: 3;
        width: 100%;
        background: #2d2d2d;
        color: #cccccc;
        padding: 0 2;
        content-align: center middle;
    }

    ModelSelectorScreen .search-input {
        height: 3;
        width: 100%;
        background: #2a2a2a;
        border: solid #555555;
        margin: 0 0 1 0;
    }

    ModelSelectorScreen .search-input:focus {
        border: solid #888888;
    }

    ModelSelectorScreen DataTable {
        height: 1fr;
        width: 100%;
        background: #1e1e1e;
        border: solid #555555;
    }

    ModelSelectorScreen .footer {
        height: 5;
        width: 100%;
        background: #2d2d2d;
        padding: 1 2;
        align: center middle;
    }

    ModelSelectorScreen Button {
        margin: 0 1;
    }
    """

    def __init__(self, models: List[dict], current_model: Optional[str] = None):
        """Create the selector.

        Args:
            models: Model description dicts to list (presumably from the
                provider's model-list API; keys read here are ``id``, ``name``,
                ``context_length``, ``pricing``, ``architecture`` and
                ``supported_parameters``).
            current_model: ID of the currently active model. It is stored but
                not used by this class.
        """
        super().__init__()
        self.all_models = models
        self.filtered_models = models
        self.current_model = current_model
        self.selected_model: Optional[dict] = None

    def compose(self) -> ComposeResult:
        """Compose the screen."""
        with Container():
            yield Static(
                f"[bold]Select Model[/] [dim]({len(self.all_models)} available)[/]",
                classes="header",
            )
            yield Input(
                placeholder="Search to filter models...",
                id="search-input",
                classes="search-input",
            )
            yield DataTable(
                id="model-table",
                cursor_type="row",
                show_header=True,
                zebra_stripes=True,
            )
            with Vertical(classes="footer"):
                yield Button("Select", id="select", variant="success")
                yield Button("Cancel", id="cancel", variant="error")

    def on_mount(self) -> None:
        """Add the table columns, fill the rows and pick the initial focus."""
        table = self.query_one("#model-table", DataTable)

        table.add_column("#", width=5)
        table.add_column("Model ID", width=35)
        table.add_column("Name", width=30)
        table.add_column("Context", width=10)
        table.add_column("Price", width=12)
        table.add_column("Img", width=4)
        table.add_column("Tools", width=6)
        table.add_column("Online", width=7)

        self._populate_table()

        # Focus the table if the list is small (fits on screen), otherwise
        # focus search so the user can narrow it down first.
        if len(self.filtered_models) <= 20:
            table.focus()
        else:
            self.query_one("#search-input", Input).focus()

    @staticmethod
    def _format_price(pricing: dict) -> str:
        """Format prompt/completion prices as dollars per 1M tokens.

        Returns ``"Free"`` when both prices are zero. Returns ``"N/A"`` when
        either price cannot be converted to a float.
        """
        try:
            # Prices are multiplied by 1e6 to show them per 1M tokens
            # (presumably the source prices are per token; verify with the API).
            prompt = float(pricing.get("prompt", "0")) * 1_000_000
            completion = float(pricing.get("completion", "0")) * 1_000_000
        except (TypeError, ValueError):
            return "N/A"
        if prompt == 0 and completion == 0:
            return "Free"
        return f"${prompt:.2f}/${completion:.2f}"

    def _build_row(self, idx: int, model: dict) -> tuple:
        """Return the display cells for one model, numbered ``idx`` (1-based)."""
        model_id = model.get("id") or ""
        name = model.get("name") or ""
        context = str(model.get("context_length", "N/A"))
        # ``or {}`` / ``or ""`` / ``or []`` also cover keys present with a
        # None value, not just missing keys.
        price = self._format_price(model.get("pricing") or {})
        architecture = model.get("architecture") or {}
        modality = architecture.get("modality") or ""
        supported_params = model.get("supported_parameters") or []

        # Vision support: modality string mentions "image".
        supports_vision = "image" in modality
        # Tool support: "tools" or "tool_choice" among supported parameters.
        supports_tools = (
            "tools" in supported_params or "tool_choice" in supported_params
        )
        # Online support: any model can take the ":online" suffix except the
        # explicitly excluded IDs; IDs already carrying ":online" support it.
        supports_online = ":online" in model_id or model_id not in ["openrouter/free"]

        return (
            str(idx),
            model_id,
            name,
            context,
            price,
            "✓" if supports_vision else "-",
            "✓" if supports_tools else "-",
            "✓" if supports_online else "-",
        )

    def _populate_table(self) -> None:
        """Replace the table's rows with one row per entry in ``filtered_models``.

        Each row key is the model's 1-based position in ``filtered_models``
        as a string. The row handlers below rely on this to map rows back
        to models.
        """
        table = self.query_one("#model-table", DataTable)
        table.clear()

        for idx, model in enumerate(self.filtered_models, 1):
            try:
                row = self._build_row(idx, model)
            except (AttributeError, TypeError, ValueError):
                # Malformed entry (e.g. not a dict, or a non-string ID): skip
                # it rather than abort the listing. Keys come from ``idx``, so
                # the remaining rows still map to the right models.
                continue
            table.add_row(*row, key=str(idx))

    def on_input_changed(self, event: Input.Changed) -> None:
        """Filter models by a case-insensitive substring match on ID or name."""
        if event.input.id != "search-input":
            return

        search_term = event.value.lower()
        if not search_term:
            self.filtered_models = self.all_models
        else:
            self.filtered_models = [
                m
                for m in self.all_models
                if search_term in (m.get("id") or "").lower()
                or search_term in (m.get("name") or "").lower()
            ]

        # Don't let "Select" return a model the filter has hidden.
        if self.selected_model is not None and not any(
            m is self.selected_model for m in self.filtered_models
        ):
            self.selected_model = None

        self._populate_table()

    def _model_for_row_key(self, key_value) -> Optional[dict]:
        """Map a row key or "#" cell value (1-based index string) to its model."""
        try:
            row_index = int(key_value) - 1
        except (TypeError, ValueError):
            return None
        if 0 <= row_index < len(self.filtered_models):
            return self.filtered_models[row_index]
        return None

    def _model_at_cursor(self, table: DataTable) -> Optional[dict]:
        """Return the model under the table cursor, or ``None`` if there is none."""
        cursor_row = table.cursor_row
        # The cursor row may still be reported for an empty (fully
        # filtered-out) table, and get_row_at would then raise, so
        # bounds-check against row_count first.
        if cursor_row is None or not 0 <= cursor_row < table.row_count:
            return None
        row = table.get_row_at(cursor_row)
        return self._model_for_row_key(row[0]) if row else None

    def on_data_table_row_selected(self, event: DataTable.RowSelected) -> None:
        """Remember the model of a selected row (Enter or click on the cursor row)."""
        model = self._model_for_row_key(event.row_key.value)
        if model is not None:
            self.selected_model = model

    def on_data_table_row_highlighted(self, event: DataTable.RowHighlighted) -> None:
        """Track the highlighted row (e.g. arrow-key navigation) as the selection."""
        model = self._model_for_row_key(event.row_key.value)
        if model is not None:
            self.selected_model = model

    def on_button_pressed(self, event: Button.Pressed) -> None:
        """Dismiss with the current selection ("Select") or with ``None``."""
        if event.button.id == "select":
            # A falsy selection (None or an empty dict) dismisses with None.
            self.dismiss(self.selected_model or None)
        else:
            self.dismiss(None)

    def on_key(self, event: events.Key) -> None:
        """Handle keyboard shortcuts.

        Escape cancels. Enter does one of three things:

        - In the search box, it moves focus to the table.
        - Elsewhere, it picks the model under the table cursor, falling back
          to the last tracked selection.
        - On a button, it is left to that button.
        """
        if event.key == "escape":
            self.dismiss(None)
            return
        if event.key != "enter":
            return

        if isinstance(self.focused, Button):
            # Let the button's own Enter handling fire Button.Pressed.
            # Otherwise Enter on "Cancel" would select the highlighted model,
            # and Enter on "Select" could dismiss twice.
            return

        search_input = self.query_one("#search-input", Input)
        table = self.query_one("#model-table", DataTable)
        if search_input.has_focus:
            table.focus()
            return

        try:
            selected = self._model_at_cursor(table)
        except (ValueError, IndexError, AttributeError):
            selected = None
        if selected is None:
            selected = self.selected_model
        if selected:
            self.dismiss(selected)