//
// ContextSelectionService.swift
// oAI
//
// Smart context selection for AI conversations
// Selects relevant messages instead of sending entire history
//
// SPDX-License-Identifier: AGPL-3.0-or-later
// Copyright (C) 2026 Rune Olsen
//
// This file is part of oAI.
//
// oAI is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// oAI is distributed in the hope that it will be useful, but WITHOUT
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
// or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General
// Public License for more details.
//
// You should have received a copy of the GNU Affero General Public
// License along with oAI. If not, see <https://www.gnu.org/licenses/>.

import Foundation
|
|
import os

// MARK: - Context Window

/// The outcome of a context selection: what gets sent, and what was left out.
struct ContextWindow {
    /// Messages chosen for the request.
    let messages: [Message]

    /// Summary texts standing in for messages that were not selected.
    let summaries: [String]

    /// Estimated token total of the selection.
    let totalTokens: Int

    /// Number of messages that were left out of `messages`.
    let excludedCount: Int
}

// MARK: - Selection Strategy

/// How much of the conversation history is forwarded with a request.
enum SelectionStrategy {
    /// Memory ON (old behavior): send all messages.
    case allMessages

    /// Memory OFF: send only the last message.
    case lastMessageOnly

    /// NEW: intelligent selection.
    case smart
}

// MARK: - Context Selection Service

/// Decides which messages of a conversation are forwarded as context.
///
/// Three strategies are supported (see `SelectionStrategy`): everything,
/// only the last message, or a "smart" selection that keeps the most recent
/// messages and tops them up with starred and important older messages while
/// a token budget allows.
final class ContextSelectionService {
    static let shared = ContextSelectionService()

    /// Tuning values used by the smart strategy and the importance score.
    private enum Tuning {
        /// Number of trailing chat messages that are always kept.
        static let recentMessageCount = 10
        /// A message is "important" when its score is strictly above this.
        static let importanceThreshold = 0.5

        /// Cost (in dollars) at which the cost factor reaches its maximum.
        static let costForFullScore = 0.01
        static let costWeight = 0.5

        /// Character count at which the length factor reaches its maximum.
        static let lengthForFullScore = 2000.0
        static let lengthWeight = 0.3

        /// Token count at which the token factor reaches its maximum.
        static let tokensForFullScore = 1000.0
        static let tokenWeight = 0.2
    }

    private init() {}

    /// Select context messages using the specified strategy.
    ///
    /// - Parameters:
    ///   - allMessages: The full conversation history.
    ///   - strategy: Which selection strategy to apply.
    ///   - maxTokens: Token budget for `.smart`. When `nil`, `.smart` falls
    ///     back to sending every message. Ignored by the other strategies.
    ///   - currentQuery: Accepted for API compatibility; not read by the
    ///     current implementation.
    ///   - conversationId: Used by `.smart` to look up stored summaries for
    ///     excluded messages. Without it no summaries are returned.
    /// - Returns: The selected messages plus bookkeeping about the selection.
    func selectContext(
        allMessages: [Message],
        strategy: SelectionStrategy,
        maxTokens: Int?,
        currentQuery: String? = nil,
        conversationId: UUID? = nil
    ) -> ContextWindow {
        switch strategy {
        case .allMessages:
            return allMessagesContext(allMessages)

        case .lastMessageOnly:
            return lastMessageOnlyContext(allMessages)

        case .smart:
            // Without a budget there is nothing to trim against.
            guard let maxTokens = maxTokens else {
                return allMessagesContext(allMessages)
            }
            return smartSelection(
                allMessages: allMessages,
                maxTokens: maxTokens,
                conversationId: conversationId
            )
        }
    }

    // MARK: - Simple Strategies

    /// A window with no messages, no summaries and zero counts.
    private var emptyWindow: ContextWindow {
        ContextWindow(messages: [], summaries: [], totalTokens: 0, excludedCount: 0)
    }

    /// Wraps the whole history unchanged; nothing is excluded.
    private func allMessagesContext(_ messages: [Message]) -> ContextWindow {
        ContextWindow(
            messages: messages,
            summaries: [],
            totalTokens: estimateTokens(messages),
            excludedCount: 0
        )
    }

    /// Keeps only the final message; every earlier message counts as excluded.
    private func lastMessageOnlyContext(_ messages: [Message]) -> ContextWindow {
        guard let last = messages.last else {
            return emptyWindow
        }
        return ContextWindow(
            messages: [last],
            summaries: [],
            totalTokens: estimateTokens([last]),
            excludedCount: messages.count - 1
        )
    }

    // MARK: - Smart Selection Algorithm

    /// Builds a context window from recent, starred and important messages.
    ///
    /// Only `.user` and `.assistant` messages take part. The most recent
    /// messages are always kept, even if they alone exceed `maxTokens`; older
    /// starred and important messages are added only while they fit.
    /// `excludedCount` is measured against the filtered chat messages.
    private func smartSelection(allMessages: [Message], maxTokens: Int, conversationId: UUID? = nil) -> ContextWindow {
        guard !allMessages.isEmpty else {
            return emptyWindow
        }

        // Filter out system messages (tools)
        let chatMessages = allMessages.filter { $0.role == .user || $0.role == .assistant }

        // Step 1: always include the last N messages (recent context).
        let recentCount = min(Tuning.recentMessageCount, chatMessages.count)
        let recentMessages = Array(chatMessages.suffix(recentCount))
        let recentTokens = estimateTokens(recentMessages)

        Log.ui.debug("Smart selection: starting with last \(recentCount) messages (\(recentTokens) tokens)")

        // Step 2: add starred messages from earlier in the conversation.
        let olderMessages = chatMessages.dropLast(recentCount)
        let starred = takeWithinBudget(
            from: olderMessages,
            remainingTokens: maxTokens - recentTokens
        ) { isMessageStarred($0) }

        if starred.ranOutOfBudget {
            Log.ui.debug("Smart selection: token budget exceeded, stopping at \(recentMessages.count + starred.messages.count) messages")
        }

        var usedTokens = recentTokens + starred.tokens

        // Step 3: add important messages (high cost, long content), but only
        // if some budget is actually left.
        var importantMessages: [Message] = []
        if usedTokens < maxTokens {
            // Set lookup instead of scanning the starred array per candidate.
            let starredIDs = Set(starred.messages.map { $0.id })
            let important = takeWithinBudget(
                from: olderMessages,
                remainingTokens: maxTokens - usedTokens
            ) { candidate in
                !starredIDs.contains(candidate.id)
                    && getImportanceScore(candidate) > Tuning.importanceThreshold
            }
            importantMessages = important.messages
            usedTokens += important.tokens
        }
        let totalTokens = usedTokens

        // Combine: starred + important + recent (in chronological order).
        let chronological = (starred.messages + importantMessages + recentMessages)
            .sorted { $0.timestamp < $1.timestamp }

        // Remove duplicates while preserving order (first occurrence wins).
        var seenIDs = Set<UUID>()
        let selectedMessages = chronological.filter { seenIDs.insert($0.id).inserted }

        let excludedCount = chatMessages.count - selectedMessages.count

        // Get summaries for excluded message ranges.
        let summaries: [String]
        if excludedCount > 0, let conversationId = conversationId {
            summaries = getSummariesForExcludedRange(
                conversationId: conversationId,
                totalMessages: chatMessages.count,
                selectedCount: selectedMessages.count
            )
        } else {
            summaries = []
        }

        Log.ui.info("Smart selection: selected \(selectedMessages.count)/\(chatMessages.count) messages (\(totalTokens) tokens, excluded: \(excludedCount), summaries: \(summaries.count))")

        return ContextWindow(
            messages: selectedMessages,
            summaries: summaries,
            totalTokens: totalTokens,
            excludedCount: excludedCount
        )
    }

    /// Picks eligible candidates in order while they fit in `remainingTokens`.
    ///
    /// Stops at the first eligible candidate that does not fit; later
    /// candidates are not examined. `isEligible` is evaluated lazily, one
    /// candidate at a time, so it is never called past the stopping point.
    ///
    /// - Returns: The picked messages, their estimated token total, and
    ///   whether the scan stopped because a candidate did not fit.
    private func takeWithinBudget(
        from candidates: ArraySlice<Message>,
        remainingTokens: Int,
        matching isEligible: (Message) -> Bool
    ) -> (messages: [Message], tokens: Int, ranOutOfBudget: Bool) {
        var picked: [Message] = []
        var spent = 0

        for candidate in candidates where isEligible(candidate) {
            let candidateTokens = estimateTokens([candidate])
            guard spent + candidateTokens <= remainingTokens else {
                return (picked, spent, true)
            }
            picked.append(candidate)
            spent += candidateTokens
        }

        return (picked, spent, false)
    }

    /// Get summaries for excluded message ranges.
    ///
    /// Returns the text of every stored summary whose `end_message_index` is
    /// below `totalMessages - selectedCount`. A failed lookup yields no
    /// summaries.
    ///
    /// NOTE(review): the cutoff is derived from counts only, so it matches the
    /// excluded messages exactly only when they form the leading part of the
    /// conversation. Confirm this is intended when starred or important
    /// messages are selected from the middle of the history.
    private func getSummariesForExcludedRange(
        conversationId: UUID,
        totalMessages: Int,
        selectedCount: Int
    ) -> [String] {
        guard let records = try? DatabaseService.shared.getConversationSummaries(conversationId: conversationId) else {
            return []
        }

        let cutoff = totalMessages - selectedCount
        return records
            .filter { $0.end_message_index < cutoff }
            .map { $0.summary }
    }

    // MARK: - Importance Scoring

    /// Calculate importance score (0.0 - 1.0) for a message.
    ///
    /// The score is a weighted sum of three factors, each capped at 1.0:
    /// cost (when present), content length in characters, and token count
    /// (when present).
    private func getImportanceScore(_ message: Message) -> Double {
        var score = 0.0

        // Factor 1: cost (expensive calls are important).
        if let cost = message.cost {
            score += min(1.0, cost / Tuning.costForFullScore) * Tuning.costWeight
        }

        // Factor 2: length (detailed messages are important).
        let length = Double(message.content.count)
        score += min(1.0, length / Tuning.lengthForFullScore) * Tuning.lengthWeight

        // Factor 3: token count (if available).
        if let tokens = message.tokens {
            score += min(1.0, Double(tokens) / Tuning.tokensForFullScore) * Tuning.tokenWeight
        }

        return min(1.0, score)
    }

    /// Check if a message is starred by the user.
    ///
    /// Performs one metadata lookup per call. A failed or empty lookup is
    /// treated as "not starred".
    private func isMessageStarred(_ message: Message) -> Bool {
        guard let metadata = try? DatabaseService.shared.getMessageMetadata(messageId: message.id) else {
            return false
        }
        return metadata.user_starred == 1
    }

    // MARK: - Token Estimation

    /// Estimate token count for messages (rough approximation).
    ///
    /// Uses the message's recorded token count when present; otherwise
    /// assumes 1 token ≈ 4 characters of content (integer division).
    private func estimateTokens(_ messages: [Message]) -> Int {
        messages.reduce(0) { runningTotal, message in
            runningTotal + (message.tokens ?? (message.content.count / 4))
        }
    }
}
|