Small feature changes and bug fixes
This commit is contained in:
247
oAI/Services/ContextSelectionService.swift
Normal file
247
oAI/Services/ContextSelectionService.swift
Normal file
@@ -0,0 +1,247 @@
|
||||
//
|
||||
// ContextSelectionService.swift
|
||||
// oAI
|
||||
//
|
||||
// Smart context selection for AI conversations
|
||||
// Selects relevant messages instead of sending entire history
|
||||
//
|
||||
|
||||
import Foundation
|
||||
import os
|
||||
|
||||
// MARK: - Context Window
|
||||
|
||||
/// Outcome of a context selection pass: the messages to send plus bookkeeping
/// about what was left out.
struct ContextWindow {
    /// Messages chosen for the request.
    let messages: [Message]

    /// Summary texts that stand in for excluded history; empty when none apply.
    let summaries: [String]

    /// Estimated token total for the selection.
    let totalTokens: Int

    /// Number of messages that were not selected.
    let excludedCount: Int
}
|
||||
|
||||
// MARK: - Selection Strategy
|
||||
|
||||
/// How much conversation history is forwarded with a request.
enum SelectionStrategy {
    /// Memory ON (old behavior): send all messages.
    case allMessages

    /// Memory OFF: send only last message.
    case lastMessageOnly

    /// NEW: intelligent selection.
    case smart
}
|
||||
|
||||
// MARK: - Context Selection Service
|
||||
|
||||
/// Chooses which messages of a conversation are sent along with a request.
///
/// Exposed as a singleton (`shared`). Depending on the `SelectionStrategy`, the
/// service returns the whole history, only the last message, or a token-budgeted
/// selection made of recent, starred and high-importance messages.
final class ContextSelectionService {
    static let shared = ContextSelectionService()

    /// Number of trailing chat messages that smart selection always keeps.
    private static let recentMessageLimit = 10

    /// Score an older message must exceed to count as "important".
    private static let importanceThreshold = 0.5

    private init() {}

    /// Select context messages using the specified strategy.
    ///
    /// - Parameters:
    ///   - allMessages: The conversation history; the last element is treated as
    ///     the most recent message.
    ///   - strategy: Which selection behaviour to apply.
    ///   - maxTokens: Token budget used by `.smart`. When `nil`, `.smart` falls
    ///     back to returning every message. Not read by the other strategies.
    ///   - currentQuery: Not read by any strategy at present.
    ///   - conversationId: Used by `.smart` to look up stored summaries for the
    ///     excluded part of the conversation; without it no summaries are returned.
    /// - Returns: The selected messages together with token and exclusion counts.
    func selectContext(
        allMessages: [Message],
        strategy: SelectionStrategy,
        maxTokens: Int?,
        currentQuery: String? = nil,
        conversationId: UUID? = nil
    ) -> ContextWindow {
        switch strategy {
        case .allMessages:
            return allMessagesContext(allMessages)

        case .lastMessageOnly:
            return lastMessageOnlyContext(allMessages)

        case .smart:
            guard let maxTokens = maxTokens else {
                // Fallback to all messages if no token limit
                return allMessagesContext(allMessages)
            }
            return smartSelection(allMessages: allMessages, maxTokens: maxTokens, conversationId: conversationId)
        }
    }

    // MARK: - Simple Strategies

    /// Wraps the full, unfiltered history in a `ContextWindow`.
    private func allMessagesContext(_ messages: [Message]) -> ContextWindow {
        ContextWindow(
            messages: messages,
            summaries: [],
            totalTokens: estimateTokens(messages),
            excludedCount: 0
        )
    }

    /// Keeps only the final message; everything before it counts as excluded.
    private func lastMessageOnlyContext(_ messages: [Message]) -> ContextWindow {
        guard let last = messages.last else {
            return ContextWindow(messages: [], summaries: [], totalTokens: 0, excludedCount: 0)
        }
        return ContextWindow(
            messages: [last],
            summaries: [],
            totalTokens: estimateTokens([last]),
            excludedCount: messages.count - 1
        )
    }

    // MARK: - Smart Selection Algorithm

    /// Builds a context from recent, starred and important messages.
    ///
    /// Selection order of priority:
    /// 1. The last `recentMessageLimit` user/assistant messages are always kept,
    ///    even if they alone exceed `maxTokens`.
    /// 2. Older starred messages are added while they fit the remaining budget.
    /// 3. Older non-starred messages whose importance score exceeds
    ///    `importanceThreshold` are added while they fit.
    ///
    /// A message that does not fit is skipped rather than ending the scan, so one
    /// oversized message cannot keep smaller, equally eligible messages out.
    ///
    /// - Parameters:
    ///   - allMessages: Full history; messages that are neither `.user` nor
    ///     `.assistant` are ignored.
    ///   - maxTokens: Budget that limits steps 2 and 3.
    ///   - conversationId: When present and something was excluded, stored
    ///     summaries are fetched for the excluded range.
    /// - Returns: The selection sorted by timestamp, with duplicates (by id) removed.
    private func smartSelection(allMessages: [Message], maxTokens: Int, conversationId: UUID? = nil) -> ContextWindow {
        guard !allMessages.isEmpty else {
            return ContextWindow(messages: [], summaries: [], totalTokens: 0, excludedCount: 0)
        }

        // Filter out system messages (tools)
        let chatMessages = allMessages.filter { $0.role == .user || $0.role == .assistant }

        // Step 1: Always include last N messages (recent context)
        let recentCount = min(Self.recentMessageLimit, chatMessages.count)
        let recentMessages = Array(chatMessages.suffix(recentCount))
        var currentTokens = estimateTokens(recentMessages)

        Log.ui.debug("Smart selection: starting with last \(recentCount) messages (\(currentTokens) tokens)")

        // Step 2: Add starred messages from earlier in conversation
        let olderMessages = chatMessages.dropLast(recentCount)
        var starredMessages: [Message] = []
        // Ids of every starred older message (selected or not) for O(1) lookup in step 3.
        var starredIDs = Set<UUID>()
        var skippedStarredCount = 0

        for message in olderMessages where isMessageStarred(message) {
            starredIDs.insert(message.id)

            let msgTokens = estimateTokens([message])
            guard currentTokens + msgTokens <= maxTokens else {
                // Too large for what is left; later starred messages may still fit.
                skippedStarredCount += 1
                continue
            }
            starredMessages.append(message)
            currentTokens += msgTokens
        }

        if skippedStarredCount > 0 {
            Log.ui.debug("Smart selection: skipped \(skippedStarredCount) starred messages that did not fit the token budget")
        }

        // Step 3: Add important messages (high cost, long content)
        var importantMessages: [Message] = []
        if currentTokens < maxTokens {
            // Starred messages were already handled in step 2; the running total
            // only grows, so a starred message that did not fit cannot fit now.
            for message in olderMessages where !starredIDs.contains(message.id) {
                guard getImportanceScore(message) > Self.importanceThreshold else { continue }

                let msgTokens = estimateTokens([message])
                guard currentTokens + msgTokens <= maxTokens else { continue }

                importantMessages.append(message)
                currentTokens += msgTokens
            }
        }

        // Combine: starred + important + recent (in chronological order)
        let chronological = (starredMessages + importantMessages + recentMessages)
            .sorted { $0.timestamp < $1.timestamp }

        // Remove duplicates while preserving order
        var seen = Set<UUID>()
        let selectedMessages = chronological.filter { seen.insert($0.id).inserted }

        let excludedCount = chatMessages.count - selectedMessages.count

        // Get summaries for excluded message ranges
        var summaries: [String] = []
        if excludedCount > 0, let conversationId = conversationId {
            summaries = getSummariesForExcludedRange(
                conversationId: conversationId,
                totalMessages: chatMessages.count,
                selectedCount: selectedMessages.count
            )
        }

        Log.ui.info("Smart selection: selected \(selectedMessages.count)/\(chatMessages.count) messages (\(currentTokens) tokens, excluded: \(excludedCount), summaries: \(summaries.count))")

        return ContextWindow(
            messages: selectedMessages,
            summaries: summaries,
            totalTokens: currentTokens,
            excludedCount: excludedCount
        )
    }

    /// Get summaries for excluded message ranges.
    ///
    /// Returns the text of every stored summary whose `end_message_index` lies
    /// below `totalMessages - selectedCount`. A failed database read yields an
    /// empty list.
    ///
    /// NOTE(review): the cutoff treats the excluded messages as a contiguous
    /// prefix of the conversation. Starred/important picks from older history
    /// make that an approximation — confirm this is intended.
    private func getSummariesForExcludedRange(
        conversationId: UUID,
        totalMessages: Int,
        selectedCount: Int
    ) -> [String] {
        guard let summaryRecords = try? DatabaseService.shared.getConversationSummaries(conversationId: conversationId) else {
            return []
        }

        // Only include summaries for messages that were excluded
        let cutoff = totalMessages - selectedCount
        return summaryRecords
            .filter { $0.end_message_index < cutoff }
            .map { $0.summary }
    }

    // MARK: - Importance Scoring

    /// Calculate importance score (0.0 - 1.0) for a message.
    ///
    /// Weighted sum of three capped factors: cost (weight 0.5), content length
    /// (weight 0.3) and token count (weight 0.2). Missing cost or token values
    /// contribute nothing.
    private func getImportanceScore(_ message: Message) -> Double {
        var score = 0.0

        // Factor 1: Cost (expensive calls are important); $0.01+ = max score
        if let cost = message.cost {
            score += min(1.0, cost / 0.01) * 0.5
        }

        // Factor 2: Length (detailed messages are important); 2000+ chars = max score
        score += min(1.0, Double(message.content.count) / 2000.0) * 0.3

        // Factor 3: Token count (if available); 1000+ tokens = max score
        if let tokens = message.tokens {
            score += min(1.0, Double(tokens) / 1000.0) * 0.2
        }

        return min(1.0, score)
    }

    /// Check if a message is starred by the user.
    ///
    /// Reads the message's metadata from `DatabaseService`; a failed or empty
    /// lookup is treated as "not starred".
    private func isMessageStarred(_ message: Message) -> Bool {
        guard let metadata = try? DatabaseService.shared.getMessageMetadata(messageId: message.id) else {
            return false
        }
        return metadata.user_starred == 1
    }

    // MARK: - Token Estimation

    /// Estimate token count for messages (rough approximation).
    ///
    /// Uses a message's recorded `tokens` when present; otherwise assumes
    /// 1 token ≈ 4 characters of content.
    private func estimateTokens(_ messages: [Message]) -> Int {
        messages.reduce(0) { total, message in
            total + (message.tokens ?? (message.content.count / 4))
        }
    }
}
|
||||
Reference in New Issue
Block a user