Small feature changes and bug fixes
This commit is contained in:
408
oAI/Services/EmbeddingService.swift
Normal file
408
oAI/Services/EmbeddingService.swift
Normal file
@@ -0,0 +1,408 @@
|
||||
//
|
||||
// EmbeddingService.swift
|
||||
// oAI
|
||||
//
|
||||
// Embedding generation and semantic search
|
||||
// Supports multiple providers: OpenAI, OpenRouter, Google
|
||||
//
|
||||
|
||||
import Foundation
|
||||
import os
|
||||
|
||||
// MARK: - Embedding Provider
|
||||
|
||||
/// An embedding backend paired with the concrete model requested from it.
enum EmbeddingProvider {
    case openai(model: String)
    case openrouter(model: String)
    case google(model: String)

    /// The model used for this provider when no explicit choice is made.
    ///
    /// This is a per-provider constant: it does NOT reflect the model carried
    /// in the case's associated value.
    var defaultModel: String {
        switch self {
        case .openai: return "text-embedding-3-small"
        case .openrouter: return "openai/text-embedding-3-small"
        case .google: return "text-embedding-004"
        }
    }

    /// Expected vector length for the model carried in the associated value.
    var dimension: Int {
        switch self {
        case .openai(let model):
            return model == "text-embedding-3-large" ? 3072 : 1536
        case .openrouter(let model):
            if model.contains("text-embedding-3-large") {
                return 3072
            } else if model.contains("qwen3-embedding-8b") {
                // Qwen3-Embedding-8B produces 4096-dimensional vectors
                // (8192 was incorrect).
                return 4096
            } else {
                return 1536 // Default for most models
            }
        case .google:
            return 768
        }
    }

    /// Provider name suitable for showing to the user and for log messages.
    var displayName: String {
        switch self {
        case .openai: return "OpenAI"
        case .openrouter: return "OpenRouter"
        case .google: return "Google"
        }
    }
}
|
||||
|
||||
// MARK: - Embedding Service
|
||||
|
||||
/// Generates text embeddings through the configured provider, persists them
/// via `DatabaseService`, and compares them with cosine similarity.
final class EmbeddingService {
    static let shared = EmbeddingService()

    private let settings = SettingsService.shared

    private init() {}

    // MARK: - Provider Detection

    /// Get the embedding provider based on user's selection in settings.
    ///
    /// - Returns: The selected provider, or `nil` when the API key it needs is
    ///   missing or empty. An unrecognised selection falls back to
    ///   `getBestAvailableProvider()`.
    func getSelectedProvider() -> EmbeddingProvider? {
        let selectedModel = settings.embeddingProvider

        // Map user's selection to provider
        switch selectedModel {
        case "openai-small":
            guard let key = settings.openaiAPIKey, !key.isEmpty else { return nil }
            return .openai(model: "text-embedding-3-small")
        case "openai-large":
            guard let key = settings.openaiAPIKey, !key.isEmpty else { return nil }
            return .openai(model: "text-embedding-3-large")
        case "openrouter-openai-small":
            guard let key = settings.openrouterAPIKey, !key.isEmpty else { return nil }
            return .openrouter(model: "openai/text-embedding-3-small")
        case "openrouter-openai-large":
            guard let key = settings.openrouterAPIKey, !key.isEmpty else { return nil }
            return .openrouter(model: "openai/text-embedding-3-large")
        case "openrouter-qwen":
            guard let key = settings.openrouterAPIKey, !key.isEmpty else { return nil }
            return .openrouter(model: "qwen/qwen3-embedding-8b")
        case "google-gemini":
            guard let key = settings.googleAPIKey, !key.isEmpty else { return nil }
            return .google(model: "text-embedding-004")
        default:
            // Fall back to best available
            return getBestAvailableProvider()
        }
    }

    /// Get the best available embedding provider based on user's API keys
    /// (priority: OpenAI → OpenRouter → Google).
    ///
    /// - Returns: The first provider with a non-empty key, or `nil` if none.
    func getBestAvailableProvider() -> EmbeddingProvider? {
        if let key = settings.openaiAPIKey, !key.isEmpty {
            return .openai(model: "text-embedding-3-small")
        }

        if let key = settings.openrouterAPIKey, !key.isEmpty {
            return .openrouter(model: "openai/text-embedding-3-small")
        }

        if let key = settings.googleAPIKey, !key.isEmpty {
            return .google(model: "text-embedding-004")
        }

        return nil
    }

    /// Check if embeddings are available (user has at least one compatible
    /// provider key configured, regardless of the selected provider).
    var isAvailable: Bool {
        return getBestAvailableProvider() != nil
    }

    // MARK: - Embedding Generation

    /// Generate embedding for text using the given provider.
    ///
    /// - Parameters:
    ///   - text: The text to embed.
    ///   - provider: Backend and model to use.
    /// - Returns: The embedding vector.
    /// - Throws: `EmbeddingError` for missing keys, non-200 responses or
    ///   unparseable payloads; rethrows networking/serialization errors.
    func generateEmbedding(text: String, provider: EmbeddingProvider) async throws -> [Float] {
        switch provider {
        case .openai(let model):
            return try await generateOpenAIEmbedding(text: text, model: model)
        case .openrouter(let model):
            return try await generateOpenRouterEmbedding(text: text, model: model)
        case .google(let model):
            return try await generateGoogleEmbedding(text: text, model: model)
        }
    }

    /// Generate OpenAI embedding
    private func generateOpenAIEmbedding(text: String, model: String) async throws -> [Float] {
        guard let apiKey = settings.openaiAPIKey, !apiKey.isEmpty else {
            throw EmbeddingError.missingAPIKey("OpenAI")
        }

        let request = try makeJSONPostRequest(
            url: URL(string: "https://api.openai.com/v1/embeddings")!,
            headers: ["Authorization": "Bearer \(apiKey)"],
            body: ["input": text, "model": model]
        )
        let (data, httpResponse) = try await send(request)

        guard httpResponse.statusCode == 200 else {
            let errorMessage = String(data: data, encoding: .utf8) ?? "Unknown error"
            Log.api.error("OpenAI embedding error (\(httpResponse.statusCode)): \(errorMessage)")
            throw EmbeddingError.apiError(httpResponse.statusCode, errorMessage)
        }

        return try parseOpenAICompatibleEmbedding(from: data)
    }

    /// Generate OpenRouter embedding (OpenAI-compatible API)
    private func generateOpenRouterEmbedding(text: String, model: String) async throws -> [Float] {
        guard let apiKey = settings.openrouterAPIKey, !apiKey.isEmpty else {
            throw EmbeddingError.missingAPIKey("OpenRouter")
        }

        let request = try makeJSONPostRequest(
            url: URL(string: "https://openrouter.ai/api/v1/embeddings")!,
            headers: [
                "Authorization": "Bearer \(apiKey)",
                "HTTP-Referer": "https://github.com/yourusername/oAI"
            ],
            body: ["input": text, "model": model]
        )
        let (data, httpResponse) = try await send(request)

        guard httpResponse.statusCode == 200 else {
            let errorMessage = String(data: data, encoding: .utf8) ?? "Unknown error"
            Log.api.error("OpenRouter embedding error (\(httpResponse.statusCode)): \(errorMessage)")
            throw EmbeddingError.apiError(httpResponse.statusCode, errorMessage)
        }

        return try parseOpenAICompatibleEmbedding(from: data)
    }

    /// Generate Google embedding
    private func generateGoogleEmbedding(text: String, model: String) async throws -> [Float] {
        guard let apiKey = settings.googleAPIKey, !apiKey.isEmpty else {
            throw EmbeddingError.missingAPIKey("Google")
        }

        // The key is sent in the `x-goog-api-key` header rather than the query
        // string so it never appears in URLs, and the URL is built without a
        // force-unwrap on interpolated input.
        let endpoint = "https://generativelanguage.googleapis.com/v1beta/models/\(model):embedContent"
        guard let url = URL(string: endpoint) else {
            throw EmbeddingError.providerNotImplemented("Unsupported Google embedding model: \(model)")
        }

        let body: [String: Any] = [
            "content": [
                "parts": [
                    ["text": text]
                ]
            ]
        ]
        let request = try makeJSONPostRequest(
            url: url,
            headers: ["x-goog-api-key": apiKey],
            body: body
        )
        let (data, httpResponse) = try await send(request)

        guard httpResponse.statusCode == 200 else {
            let errorMessage = String(data: data, encoding: .utf8) ?? "Unknown error"
            Log.api.error("Google embedding error (\(httpResponse.statusCode)): \(errorMessage)")
            throw EmbeddingError.apiError(httpResponse.statusCode, errorMessage)
        }

        let json = try JSONSerialization.jsonObject(with: data) as? [String: Any]
        guard let embedding = json?["embedding"] as? [String: Any],
              let values = embedding["values"] as? [Double] else {
            throw EmbeddingError.invalidResponse
        }

        return values.map { Float($0) }
    }

    // MARK: - Networking Helpers

    /// Builds a JSON `POST` request carrying `body` plus the extra `headers`.
    private func makeJSONPostRequest(url: URL, headers: [String: String], body: [String: Any]) throws -> URLRequest {
        var request = URLRequest(url: url)
        request.httpMethod = "POST"
        request.setValue("application/json", forHTTPHeaderField: "Content-Type")
        for (field, value) in headers {
            request.setValue(value, forHTTPHeaderField: field)
        }
        request.httpBody = try JSONSerialization.data(withJSONObject: body)
        return request
    }

    /// Performs `request` on the shared session and returns the body together
    /// with the HTTP response; throws `invalidResponse` for non-HTTP responses.
    private func send(_ request: URLRequest) async throws -> (Data, HTTPURLResponse) {
        let (data, response) = try await URLSession.shared.data(for: request)
        guard let httpResponse = response as? HTTPURLResponse else {
            throw EmbeddingError.invalidResponse
        }
        return (data, httpResponse)
    }

    /// Extracts `data[0].embedding` from an OpenAI-style embeddings payload.
    private func parseOpenAICompatibleEmbedding(from data: Data) throws -> [Float] {
        let json = try JSONSerialization.jsonObject(with: data) as? [String: Any]
        guard let dataArray = json?["data"] as? [[String: Any]],
              let first = dataArray.first,
              let embedding = first["embedding"] as? [Double] else {
            throw EmbeddingError.invalidResponse
        }
        return embedding.map { Float($0) }
    }

    /// Returns the model identifier carried by `provider`'s associated value.
    private func modelIdentifier(of provider: EmbeddingProvider) -> String {
        switch provider {
        case .openai(let model), .openrouter(let model), .google(let model):
            return model
        }
    }

    // MARK: - Similarity Calculation

    /// Calculate cosine similarity between two embeddings.
    ///
    /// - Returns: The cosine of the angle between `a` and `b`, or `0.0` when
    ///   the lengths differ (logged) or either vector has zero magnitude
    ///   (which includes empty input).
    func cosineSimilarity(_ a: [Float], _ b: [Float]) -> Float {
        guard a.count == b.count else {
            Log.api.error("Embedding dimension mismatch: \(a.count) vs \(b.count)")
            return 0.0
        }

        var dotProduct: Float = 0.0
        var magnitudeA: Float = 0.0
        var magnitudeB: Float = 0.0

        for (x, y) in zip(a, b) {
            dotProduct += x * y
            magnitudeA += x * x
            magnitudeB += y * y
        }

        magnitudeA = sqrt(magnitudeA)
        magnitudeB = sqrt(magnitudeB)

        guard magnitudeA > 0 && magnitudeB > 0 else {
            return 0.0
        }

        return dotProduct / (magnitudeA * magnitudeB)
    }

    // MARK: - Database Operations

    /// Save message embedding to database; the stored dimension is the
    /// vector's actual length.
    func saveMessageEmbedding(messageId: UUID, embedding: [Float], model: String) throws {
        let data = serializeEmbedding(embedding)
        try DatabaseService.shared.saveMessageEmbedding(
            messageId: messageId,
            embedding: data,
            model: model,
            dimension: embedding.count
        )
    }

    /// Get message embedding from database, or `nil` if none is stored.
    func getMessageEmbedding(messageId: UUID) throws -> [Float]? {
        guard let data = try DatabaseService.shared.getMessageEmbedding(messageId: messageId) else {
            return nil
        }
        return deserializeEmbedding(data)
    }

    /// Save conversation embedding to database; the stored dimension is the
    /// vector's actual length.
    func saveConversationEmbedding(conversationId: UUID, embedding: [Float], model: String) throws {
        let data = serializeEmbedding(embedding)
        try DatabaseService.shared.saveConversationEmbedding(
            conversationId: conversationId,
            embedding: data,
            model: model,
            dimension: embedding.count
        )
    }

    /// Get conversation embedding from database, or `nil` if none is stored.
    func getConversationEmbedding(conversationId: UUID) throws -> [Float]? {
        guard let data = try DatabaseService.shared.getConversationEmbedding(conversationId: conversationId) else {
            return nil
        }
        return deserializeEmbedding(data)
    }

    // MARK: - Serialization

    /// Serialize embedding to binary data (4 bytes per float, little-endian)
    private func serializeEmbedding(_ embedding: [Float]) -> Data {
        var data = Data(capacity: embedding.count * 4)
        for value in embedding {
            var littleEndian = value.bitPattern.littleEndian
            withUnsafeBytes(of: &littleEndian) { bytes in
                data.append(contentsOf: bytes)
            }
        }
        return data
    }

    /// Deserialize embedding from binary data (4 bytes per float,
    /// little-endian).
    ///
    /// Trailing bytes that do not form a complete 4-byte value are ignored
    /// (and logged) instead of crashing on an out-of-range read.
    private func deserializeEmbedding(_ data: Data) -> [Float] {
        let usableByteCount = data.count - data.count % 4
        if usableByteCount != data.count {
            Log.api.error("Embedding data length \(data.count) is not a multiple of 4; ignoring trailing bytes")
        }

        // Copy into a zero-based array so offsets are valid even when `data`
        // is a slice with a non-zero start index, and no aligned load is needed.
        let bytes = [UInt8](data.prefix(usableByteCount))

        var embedding: [Float] = []
        embedding.reserveCapacity(usableByteCount / 4)

        for offset in stride(from: 0, to: bytes.count, by: 4) {
            let b0 = UInt32(bytes[offset])
            let b1 = UInt32(bytes[offset + 1]) << 8
            let b2 = UInt32(bytes[offset + 2]) << 16
            let b3 = UInt32(bytes[offset + 3]) << 24
            embedding.append(Float(bitPattern: b0 | b1 | b2 | b3))
        }

        return embedding
    }

    // MARK: - Conversation Embedding Generation

    /// Generate embedding for an entire conversation (aggregate of messages)
    /// and store it.
    ///
    /// - Throws: `EmbeddingError.noProvidersAvailable` when no provider can be
    ///   resolved, `EmbeddingError.conversationNotFound` when the conversation
    ///   cannot be loaded (any load error is mapped to this), plus anything
    ///   thrown by embedding generation or saving.
    func generateConversationEmbedding(conversationId: UUID) async throws {
        // Use user's selected provider (unknown selections fall back to best available)
        guard let provider = getSelectedProvider() else {
            throw EmbeddingError.noProvidersAvailable
        }

        // Load conversation messages
        guard let (_, messages) = try? DatabaseService.shared.loadConversation(id: conversationId) else {
            throw EmbeddingError.conversationNotFound
        }

        // Combine all message content
        let chatMessages = messages.filter { $0.role == .user || $0.role == .assistant }
        let combinedText = chatMessages.map { $0.content }.joined(separator: "\n\n")

        // Truncate if too long (8191 tokens max for most embedding models)
        let truncated = String(combinedText.prefix(30000)) // ~7500 tokens

        // Generate embedding
        let embedding = try await generateEmbedding(text: truncated, provider: provider)

        // Save to database, recording the model that actually produced the
        // vector rather than the provider's default model.
        try saveConversationEmbedding(
            conversationId: conversationId,
            embedding: embedding,
            model: modelIdentifier(of: provider)
        )

        Log.api.info("Generated conversation embedding for \(conversationId) using \(provider.displayName) (\(embedding.count) dimensions)")
    }
}
|
||||
|
||||
// MARK: - Errors
|
||||
|
||||
/// Failures that can be raised while generating or storing embeddings.
enum EmbeddingError: LocalizedError {
    case missingAPIKey(String)
    case invalidResponse
    case apiError(Int, String)
    case providerNotImplemented(String)
    case conversationNotFound
    case noProvidersAvailable

    /// User-facing explanation of the failure.
    var errorDescription: String? {
        let description: String

        switch self {
        case .noProvidersAvailable:
            description = "No embedding providers available. Please configure an API key for OpenAI, OpenRouter, or Google."
        case .conversationNotFound:
            description = "Conversation not found"
        case .invalidResponse:
            description = "Invalid response from embedding API"
        case let .providerNotImplemented(text):
            // The payload already is the complete message.
            description = text
        case let .missingAPIKey(providerName):
            description = "\(providerName) API key not configured"
        case let .apiError(statusCode, details):
            description = "Embedding API error (\(statusCode)): \(details)"
        }

        return description
    }
}
|
||||
Reference in New Issue
Block a user