Files
oai-swift/oAI/Services/ContextSelectionService.swift

266 lines
9.0 KiB
Swift

//
// ContextSelectionService.swift
// oAI
//
// Smart context selection for AI conversations
// Selects relevant messages instead of sending entire history
//
// SPDX-License-Identifier: AGPL-3.0-or-later
// Copyright (C) 2026 Rune Olsen
//
// This file is part of oAI.
//
// oAI is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// oAI is distributed in the hope that it will be useful, but WITHOUT
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
// or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General
// Public License for more details.
//
// You should have received a copy of the GNU Affero General Public
// License along with oAI. If not, see <https://www.gnu.org/licenses/>.
import Foundation
import os
// MARK: - Context Window
/// Result of a context selection: the messages chosen for a request plus
/// bookkeeping about what was left out.
struct ContextWindow {
    /// Messages selected to be sent.
    let messages: [Message]
    /// Summary texts standing in for messages that were not selected.
    let summaries: [String]
    /// Estimated token count of `messages`.
    let totalTokens: Int
    /// Number of candidate messages that were not selected.
    let excludedCount: Int
}
// MARK: - Selection Strategy
/// How the conversation history is reduced before it is sent to the model.
enum SelectionStrategy {
    /// Memory ON (previous behavior): every message is sent.
    case allMessages
    /// Memory OFF: only the most recent message is sent.
    case lastMessageOnly
    /// Budget-aware selection of recent, starred and important messages.
    case smart
}
// MARK: - Context Selection Service
/// Chooses which messages of a conversation are sent to the model.
///
/// Accessed through the `shared` singleton. The `.smart` strategy reads star
/// flags and conversation summaries from `DatabaseService.shared`.
final class ContextSelectionService {
    static let shared = ContextSelectionService()

    /// Number of most recent chat messages the smart strategy always includes.
    private static let recentMessageCount = 10

    /// Score an older message must exceed to be treated as "important".
    private static let importanceThreshold = 0.5

    private init() {}

    /// Select context messages using the specified strategy.
    ///
    /// - Parameters:
    ///   - allMessages: The conversation history. The smart strategy treats the
    ///     end of this array as the most recent messages.
    ///   - strategy: Which selection algorithm to apply.
    ///   - maxTokens: Token budget for `.smart`. When `nil`, `.smart` falls back
    ///     to sending every message. Ignored by the other strategies.
    ///   - currentQuery: Not currently used by any strategy.
    ///   - conversationId: When provided, `.smart` looks up stored summaries for
    ///     the messages it excludes.
    /// - Returns: The selected messages, any summaries, an estimated token total
    ///   and the number of messages left out.
    func selectContext(
        allMessages: [Message],
        strategy: SelectionStrategy,
        maxTokens: Int?,
        currentQuery: String? = nil,
        conversationId: UUID? = nil
    ) -> ContextWindow {
        switch strategy {
        case .allMessages:
            return allMessagesContext(allMessages)
        case .lastMessageOnly:
            return lastMessageOnlyContext(allMessages)
        case .smart:
            guard let maxTokens = maxTokens else {
                // Fallback to all messages if no token limit
                return allMessagesContext(allMessages)
            }
            return smartSelection(allMessages: allMessages, maxTokens: maxTokens, conversationId: conversationId)
        }
    }

    // MARK: - Simple Strategies

    /// Sends the whole history unchanged; nothing is excluded.
    private func allMessagesContext(_ messages: [Message]) -> ContextWindow {
        ContextWindow(
            messages: messages,
            summaries: [],
            totalTokens: estimateTokens(messages),
            excludedCount: 0
        )
    }

    /// Sends only the final message. Every earlier message counts as excluded.
    private func lastMessageOnlyContext(_ messages: [Message]) -> ContextWindow {
        guard let last = messages.last else {
            return ContextWindow(messages: [], summaries: [], totalTokens: 0, excludedCount: 0)
        }
        return ContextWindow(
            messages: [last],
            summaries: [],
            totalTokens: estimateTokens([last]),
            excludedCount: messages.count - 1
        )
    }

    // MARK: - Smart Selection Algorithm

    /// Budget-aware selection over user/assistant messages.
    ///
    /// 1. The last `recentMessageCount` messages are always included, even if
    ///    they alone exceed `maxTokens`.
    /// 2. Older starred messages are added, oldest first, while they fit in the
    ///    remaining budget.
    /// 3. Older messages whose importance score exceeds `importanceThreshold` are
    ///    added, oldest first, while they fit.
    ///
    /// In steps 2 and 3, a message that does not fit is skipped rather than
    /// ending the pass. One oversized message therefore cannot block smaller
    /// ones after it, and a starred message is never passed over just because an
    /// earlier starred one was too large.
    /// Messages with any other role are dropped entirely and are not counted in
    /// `excludedCount`.
    private func smartSelection(allMessages: [Message], maxTokens: Int, conversationId: UUID? = nil) -> ContextWindow {
        guard !allMessages.isEmpty else {
            return ContextWindow(messages: [], summaries: [], totalTokens: 0, excludedCount: 0)
        }

        // Only user/assistant turns are candidates (tool/system messages are dropped).
        let chatMessages = allMessages.filter { $0.role == .user || $0.role == .assistant }

        // Step 1: always include the most recent messages.
        let recentCount = min(Self.recentMessageCount, chatMessages.count)
        let recentMessages = Array(chatMessages.suffix(recentCount))
        var currentTokens = estimateTokens(recentMessages)
        var selectedIDs = Set(recentMessages.map(\.id))
        Log.ui.debug("Smart selection: starting with last \(recentCount) messages (\(currentTokens) tokens)")

        let olderMessages = chatMessages.dropLast(recentCount)

        // Step 2: starred messages from earlier in the conversation.
        for message in olderMessages where isMessageStarred(message) {
            let msgTokens = estimateTokens([message])
            guard currentTokens + msgTokens <= maxTokens else {
                Log.ui.debug("Smart selection: starred message (\(msgTokens) tokens) exceeds remaining budget, skipping")
                continue
            }
            selectedIDs.insert(message.id)
            currentTokens += msgTokens
        }

        // Step 3: important messages (costly, long or token-heavy), if budget remains.
        if currentTokens < maxTokens {
            for message in olderMessages where !selectedIDs.contains(message.id) {
                guard getImportanceScore(message) > Self.importanceThreshold else { continue }
                let msgTokens = estimateTokens([message])
                guard currentTokens + msgTokens <= maxTokens else { continue }
                selectedIDs.insert(message.id)
                currentTokens += msgTokens
            }
        }

        // Emit the selection in the conversation's own order rather than
        // re-sorting by timestamp, so turns with equal or skewed timestamps keep
        // their original sequence. Each ID is emitted at most once.
        var emittedIDs = Set<UUID>()
        let selectedMessages = chatMessages.filter { message in
            selectedIDs.contains(message.id) && emittedIDs.insert(message.id).inserted
        }

        let excludedCount = chatMessages.count - selectedMessages.count

        // Get summaries for excluded message ranges
        var summaries: [String] = []
        if excludedCount > 0, let conversationId = conversationId {
            summaries = getSummariesForExcludedRange(
                conversationId: conversationId,
                totalMessages: chatMessages.count,
                selectedCount: selectedMessages.count
            )
        }

        Log.ui.info("Smart selection: selected \(selectedMessages.count)/\(chatMessages.count) messages (\(currentTokens) tokens, excluded: \(excludedCount), summaries: \(summaries.count))")

        return ContextWindow(
            messages: selectedMessages,
            summaries: summaries,
            totalTokens: currentTokens,
            excludedCount: excludedCount
        )
    }

    /// Returns stored summaries whose `end_message_index` lies before position
    /// `totalMessages - selectedCount`, i.e. within the first "excluded-count" positions.
    ///
    /// NOTE(review): this treats the excluded messages as a contiguous prefix of
    /// the conversation. Starred/important messages taken from that prefix can
    /// therefore also be covered by a returned summary. It is also unclear whether
    /// `end_message_index` counts chat-only messages or all messages. Confirm both
    /// against how summaries are written.
    /// Database errors are swallowed (`try?`) and yield no summaries.
    private func getSummariesForExcludedRange(
        conversationId: UUID,
        totalMessages: Int,
        selectedCount: Int
    ) -> [String] {
        guard let summaryRecords = try? DatabaseService.shared.getConversationSummaries(conversationId: conversationId) else {
            return []
        }
        let excludedPrefixLength = totalMessages - selectedCount
        return summaryRecords
            .filter { $0.end_message_index < excludedPrefixLength }
            .map { $0.summary }
    }

    // MARK: - Importance Scoring

    /// Calculate an importance score in 0.0 ... 1.0 for a message.
    ///
    /// The score is a weighted sum of three saturating factors:
    /// - cost, weight 0.5 (saturates at 0.01, i.e. $0.01+);
    /// - content length, weight 0.3 (saturates at 2000 characters);
    /// - recorded token count, weight 0.2 (saturates at 1000 tokens).
    /// Missing cost or token values contribute nothing.
    private func getImportanceScore(_ message: Message) -> Double {
        var score = 0.0

        // Factor 1: Cost (expensive calls are important)
        if let cost = message.cost {
            let costScore = min(1.0, cost / 0.01) // $0.01+ = max score
            score += costScore * 0.5
        }

        // Factor 2: Length (detailed messages are important)
        let contentLength = Double(message.content.count)
        let lengthScore = min(1.0, contentLength / 2000.0) // 2000+ chars = max score
        score += lengthScore * 0.3

        // Factor 3: Token count (if available)
        if let tokens = message.tokens {
            let tokenScore = min(1.0, Double(tokens) / 1000.0) // 1000+ tokens = max score
            score += tokenScore * 0.2
        }

        return min(1.0, score)
    }

    /// Whether the user starred this message. Lookup failures count as "not starred".
    private func isMessageStarred(_ message: Message) -> Bool {
        guard let metadata = try? DatabaseService.shared.getMessageMetadata(messageId: message.id) else {
            return false
        }
        return metadata.user_starred == 1
    }

    // MARK: - Token Estimation

    /// Estimate the token count for messages (rough approximation).
    ///
    /// Uses the message's recorded `tokens` when present. Otherwise it assumes
    /// about 1 token per 4 characters of content, with integer division.
    private func estimateTokens(_ messages: [Message]) -> Int {
        messages.reduce(0) { total, message in
            total + (message.tokens ?? message.content.count / 4)
        }
    }
}