427 lines
16 KiB
Swift
427 lines
16 KiB
Swift
//
|
|
// EmbeddingService.swift
|
|
// oAI
|
|
//
|
|
// Embedding generation and semantic search
|
|
// Supports multiple providers: OpenAI, OpenRouter, Google
|
|
//
|
|
// SPDX-License-Identifier: AGPL-3.0-or-later
|
|
// Copyright (C) 2026 Rune Olsen
|
|
//
|
|
// This file is part of oAI.
|
|
//
|
|
// oAI is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU Affero General Public License as
|
|
// published by the Free Software Foundation, either version 3 of the
|
|
// License, or (at your option) any later version.
|
|
//
|
|
// oAI is distributed in the hope that it will be useful, but WITHOUT
|
|
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
|
|
// or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General
|
|
// Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU Affero General Public
|
|
// License along with oAI. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
|
|
import Foundation
|
|
import os
|
|
|
|
// MARK: - Embedding Provider
|
|
|
|
/// A remote embedding backend together with the concrete model identifier to request.
enum EmbeddingProvider {
    /// OpenAI `/v1/embeddings`, e.g. "text-embedding-3-small" or "text-embedding-3-large".
    case openai(model: String)
    /// OpenRouter's OpenAI-compatible embeddings endpoint, e.g. "openai/text-embedding-3-small".
    case openrouter(model: String)
    /// Google Generative Language `embedContent`, e.g. "text-embedding-004".
    case google(model: String)

    /// The default model for this provider family.
    ///
    /// NOTE: this ignores the associated `model` value; it is the family default, not
    /// necessarily the model this case was constructed with.
    var defaultModel: String {
        switch self {
        case .openai: return "text-embedding-3-small"
        case .openrouter: return "openai/text-embedding-3-small"
        case .google: return "text-embedding-004"
        }
    }

    /// Expected length of the vector returned for the associated model.
    var dimension: Int {
        switch self {
        case .openai(let model):
            return model == "text-embedding-3-large" ? 3072 : 1536
        case .openrouter(let model):
            if model.contains("text-embedding-3-large") {
                return 3072
            } else if model.contains("qwen3-embedding-8b") {
                // Qwen3-Embedding-8B produces vectors of up to 4096 dimensions (not 8192).
                return 4096
            } else {
                return 1536 // Default for most models
            }
        case .google:
            return 768
        }
    }

    /// Human-readable provider name used in logs and UI.
    var displayName: String {
        switch self {
        case .openai: return "OpenAI"
        case .openrouter: return "OpenRouter"
        case .google: return "Google"
        }
    }
}
|
|
|
|
// MARK: - Embedding Service
|
|
|
|
/// Generates embeddings via remote providers, persists them, and computes similarity.
final class EmbeddingService {
    static let shared = EmbeddingService()

    private let settings = SettingsService.shared

    private init() {}

    // MARK: - Provider Detection

    /// Get the embedding provider for the user's selection in settings.
    ///
    /// Returns `nil` when the selected provider's API key is missing or empty; in that
    /// case it does NOT fall back to another provider. Only an unrecognized selection
    /// falls back to `getBestAvailableProvider()`.
    func getSelectedProvider() -> EmbeddingProvider? {
        let selectedModel = settings.embeddingProvider

        // Map the user's selection to a concrete provider/model pair.
        switch selectedModel {
        case "openai-small":
            guard let key = settings.openaiAPIKey, !key.isEmpty else { return nil }
            return .openai(model: "text-embedding-3-small")
        case "openai-large":
            guard let key = settings.openaiAPIKey, !key.isEmpty else { return nil }
            return .openai(model: "text-embedding-3-large")
        case "openrouter-openai-small":
            guard let key = settings.openrouterAPIKey, !key.isEmpty else { return nil }
            return .openrouter(model: "openai/text-embedding-3-small")
        case "openrouter-openai-large":
            guard let key = settings.openrouterAPIKey, !key.isEmpty else { return nil }
            return .openrouter(model: "openai/text-embedding-3-large")
        case "openrouter-qwen":
            guard let key = settings.openrouterAPIKey, !key.isEmpty else { return nil }
            return .openrouter(model: "qwen/qwen3-embedding-8b")
        case "google-gemini":
            guard let key = settings.googleAPIKey, !key.isEmpty else { return nil }
            return .google(model: "text-embedding-004")
        default:
            // Unrecognized selection: fall back to best available
            return getBestAvailableProvider()
        }
    }

    /// Get the best available embedding provider based on the user's API keys
    /// (priority: OpenAI → OpenRouter → Google), always using the small/default model.
    func getBestAvailableProvider() -> EmbeddingProvider? {
        if let key = settings.openaiAPIKey, !key.isEmpty {
            return .openai(model: "text-embedding-3-small")
        }

        if let key = settings.openrouterAPIKey, !key.isEmpty {
            return .openrouter(model: "openai/text-embedding-3-small")
        }

        if let key = settings.googleAPIKey, !key.isEmpty {
            return .google(model: "text-embedding-004")
        }

        return nil
    }

    /// Whether at least one provider has a non-empty API key configured.
    var isAvailable: Bool {
        return getBestAvailableProvider() != nil
    }

    // MARK: - Embedding Generation

    /// Generate an embedding for `text` using `provider`.
    ///
    /// - Throws: `EmbeddingError.missingAPIKey`, `.apiError`, `.invalidResponse`, or
    ///   networking/JSON errors from Foundation.
    func generateEmbedding(text: String, provider: EmbeddingProvider) async throws -> [Float] {
        switch provider {
        case .openai(let model):
            return try await generateOpenAIEmbedding(text: text, model: model)
        case .openrouter(let model):
            return try await generateOpenRouterEmbedding(text: text, model: model)
        case .google(let model):
            return try await generateGoogleEmbedding(text: text, model: model)
        }
    }

    /// Returns the model identifier actually carried by `provider` (not the family
    /// default), so stored embeddings record the model that produced them.
    private func modelIdentifier(for provider: EmbeddingProvider) -> String {
        switch provider {
        case .openai(let model), .openrouter(let model), .google(let model):
            return model
        }
    }

    /// Extracts `data[0].embedding` from an OpenAI-compatible `/v1/embeddings` response body.
    ///
    /// - Throws: `EmbeddingError.invalidResponse` when the JSON does not have that shape,
    ///   or the `JSONSerialization` error when the body is not JSON at all.
    private func parseOpenAICompatibleEmbedding(_ data: Data) throws -> [Float] {
        let json = try JSONSerialization.jsonObject(with: data) as? [String: Any]
        guard let dataArray = json?["data"] as? [[String: Any]],
              let first = dataArray.first,
              let embedding = first["embedding"] as? [Double] else {
            throw EmbeddingError.invalidResponse
        }
        return embedding.map { Float($0) }
    }

    /// Generate an OpenAI embedding.
    private func generateOpenAIEmbedding(text: String, model: String) async throws -> [Float] {
        guard let apiKey = settings.openaiAPIKey, !apiKey.isEmpty else {
            throw EmbeddingError.missingAPIKey("OpenAI")
        }

        let url = URL(string: "https://api.openai.com/v1/embeddings")!
        var request = URLRequest(url: url)
        request.httpMethod = "POST"
        request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
        request.setValue("application/json", forHTTPHeaderField: "Content-Type")

        let body: [String: Any] = [
            "input": text,
            "model": model
        ]
        request.httpBody = try JSONSerialization.data(withJSONObject: body)

        let (data, response) = try await URLSession.shared.data(for: request)

        guard let httpResponse = response as? HTTPURLResponse else {
            throw EmbeddingError.invalidResponse
        }

        guard httpResponse.statusCode == 200 else {
            let errorMessage = String(data: data, encoding: .utf8) ?? "Unknown error"
            Log.api.error("OpenAI embedding error (\(httpResponse.statusCode)): \(errorMessage)")
            throw EmbeddingError.apiError(httpResponse.statusCode, errorMessage)
        }

        return try parseOpenAICompatibleEmbedding(data)
    }

    /// Generate an OpenRouter embedding (OpenAI-compatible API).
    private func generateOpenRouterEmbedding(text: String, model: String) async throws -> [Float] {
        guard let apiKey = settings.openrouterAPIKey, !apiKey.isEmpty else {
            throw EmbeddingError.missingAPIKey("OpenRouter")
        }

        let url = URL(string: "https://openrouter.ai/api/v1/embeddings")!
        var request = URLRequest(url: url)
        request.httpMethod = "POST"
        request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
        request.setValue("application/json", forHTTPHeaderField: "Content-Type")
        // NOTE(review): placeholder attribution URL — replace with the real project URL.
        request.setValue("https://github.com/yourusername/oAI", forHTTPHeaderField: "HTTP-Referer")

        let body: [String: Any] = [
            "input": text,
            "model": model
        ]
        request.httpBody = try JSONSerialization.data(withJSONObject: body)

        let (data, response) = try await URLSession.shared.data(for: request)

        guard let httpResponse = response as? HTTPURLResponse else {
            throw EmbeddingError.invalidResponse
        }

        guard httpResponse.statusCode == 200 else {
            let errorMessage = String(data: data, encoding: .utf8) ?? "Unknown error"
            Log.api.error("OpenRouter embedding error (\(httpResponse.statusCode)): \(errorMessage)")
            throw EmbeddingError.apiError(httpResponse.statusCode, errorMessage)
        }

        return try parseOpenAICompatibleEmbedding(data)
    }

    /// Generate a Google embedding via the Generative Language `embedContent` endpoint.
    private func generateGoogleEmbedding(text: String, model: String) async throws -> [Float] {
        guard let apiKey = settings.googleAPIKey, !apiKey.isEmpty else {
            throw EmbeddingError.missingAPIKey("Google")
        }

        // The key is sent in the `x-goog-api-key` header rather than the `?key=` query
        // parameter so it does not end up in URL logs, proxies, or error messages.
        let url = URL(string: "https://generativelanguage.googleapis.com/v1beta/models/\(model):embedContent")!
        var request = URLRequest(url: url)
        request.httpMethod = "POST"
        request.setValue("application/json", forHTTPHeaderField: "Content-Type")
        request.setValue(apiKey, forHTTPHeaderField: "x-goog-api-key")

        let body: [String: Any] = [
            "content": [
                "parts": [
                    ["text": text]
                ]
            ]
        ]
        request.httpBody = try JSONSerialization.data(withJSONObject: body)

        let (data, response) = try await URLSession.shared.data(for: request)

        guard let httpResponse = response as? HTTPURLResponse else {
            throw EmbeddingError.invalidResponse
        }

        guard httpResponse.statusCode == 200 else {
            let errorMessage = String(data: data, encoding: .utf8) ?? "Unknown error"
            Log.api.error("Google embedding error (\(httpResponse.statusCode)): \(errorMessage)")
            throw EmbeddingError.apiError(httpResponse.statusCode, errorMessage)
        }

        let json = try JSONSerialization.jsonObject(with: data) as? [String: Any]
        guard let embedding = json?["embedding"] as? [String: Any],
              let values = embedding["values"] as? [Double] else {
            throw EmbeddingError.invalidResponse
        }

        return values.map { Float($0) }
    }

    // MARK: - Similarity Calculation

    /// Cosine similarity between two embeddings.
    ///
    /// Returns 0 when the lengths differ (logged) or when either vector has zero magnitude.
    func cosineSimilarity(_ a: [Float], _ b: [Float]) -> Float {
        guard a.count == b.count else {
            Log.api.error("Embedding dimension mismatch: \(a.count) vs \(b.count)")
            return 0.0
        }

        var dotProduct: Float = 0.0
        var magnitudeA: Float = 0.0
        var magnitudeB: Float = 0.0

        for (x, y) in zip(a, b) {
            dotProduct += x * y
            magnitudeA += x * x
            magnitudeB += y * y
        }

        magnitudeA = sqrt(magnitudeA)
        magnitudeB = sqrt(magnitudeB)

        guard magnitudeA > 0 && magnitudeB > 0 else {
            return 0.0
        }

        return dotProduct / (magnitudeA * magnitudeB)
    }

    // MARK: - Database Operations

    /// Save a message embedding to the database.
    func saveMessageEmbedding(messageId: UUID, embedding: [Float], model: String) throws {
        let data = serializeEmbedding(embedding)
        try DatabaseService.shared.saveMessageEmbedding(
            messageId: messageId,
            embedding: data,
            model: model,
            dimension: embedding.count
        )
    }

    /// Load a message embedding from the database, or `nil` if none is stored.
    func getMessageEmbedding(messageId: UUID) throws -> [Float]? {
        guard let data = try DatabaseService.shared.getMessageEmbedding(messageId: messageId) else {
            return nil
        }
        return deserializeEmbedding(data)
    }

    /// Save a conversation embedding to the database.
    func saveConversationEmbedding(conversationId: UUID, embedding: [Float], model: String) throws {
        let data = serializeEmbedding(embedding)
        try DatabaseService.shared.saveConversationEmbedding(
            conversationId: conversationId,
            embedding: data,
            model: model,
            dimension: embedding.count
        )
    }

    /// Load a conversation embedding from the database, or `nil` if none is stored.
    func getConversationEmbedding(conversationId: UUID) throws -> [Float]? {
        guard let data = try DatabaseService.shared.getConversationEmbedding(conversationId: conversationId) else {
            return nil
        }
        return deserializeEmbedding(data)
    }

    // MARK: - Serialization

    /// Serialize an embedding to binary data (4 bytes per float, little-endian IEEE-754 bits).
    private func serializeEmbedding(_ embedding: [Float]) -> Data {
        var data = Data(capacity: embedding.count * 4)
        for value in embedding {
            var littleEndian = value.bitPattern.littleEndian
            withUnsafeBytes(of: &littleEndian) { bytes in
                data.append(contentsOf: bytes)
            }
        }
        return data
    }

    /// Deserialize an embedding written by `serializeEmbedding`.
    ///
    /// Bytes are assembled explicitly as little-endian, so no aligned memory load is needed.
    /// A blob whose length is not a multiple of 4 (e.g. a truncated row) no longer traps:
    /// the trailing partial value is dropped and the condition is logged.
    private func deserializeEmbedding(_ data: Data) -> [Float] {
        let floatCount = data.count / 4
        if data.count % 4 != 0 {
            Log.api.error("Embedding blob length \(data.count) is not a multiple of 4; ignoring trailing bytes")
        }

        var embedding: [Float] = []
        embedding.reserveCapacity(floatCount)

        data.withUnsafeBytes { raw in
            for index in 0..<floatCount {
                let base = index * 4
                let bits = UInt32(raw[base])
                    | UInt32(raw[base + 1]) << 8
                    | UInt32(raw[base + 2]) << 16
                    | UInt32(raw[base + 3]) << 24
                embedding.append(Float(bitPattern: bits))
            }
        }

        return embedding
    }

    // MARK: - Conversation Embedding Generation

    /// Generate and store an embedding for an entire conversation (user + assistant messages).
    ///
    /// - Throws: `EmbeddingError.noProvidersAvailable` when `getSelectedProvider()` returns nil,
    ///   `EmbeddingError.conversationNotFound` when the conversation cannot be loaded, plus any
    ///   error from embedding generation or saving.
    func generateConversationEmbedding(conversationId: UUID) async throws {
        // Use the user's selected provider. getSelectedProvider() only falls back to the
        // best available provider when the selection itself is unrecognized.
        guard let provider = getSelectedProvider() else {
            throw EmbeddingError.noProvidersAvailable
        }

        // Load conversation messages. NOTE: `try?` reports database errors as
        // `conversationNotFound` too.
        guard let (_, messages) = try? DatabaseService.shared.loadConversation(id: conversationId) else {
            throw EmbeddingError.conversationNotFound
        }

        // Combine all chat message content
        let chatMessages = messages.filter { $0.role == .user || $0.role == .assistant }
        let combinedText = chatMessages.map { $0.content }.joined(separator: "\n\n")

        // Truncate if too long (8191 tokens max for most embedding models)
        let truncated = String(combinedText.prefix(30000)) // ~7500 tokens

        let embedding = try await generateEmbedding(text: truncated, provider: provider)

        // Record the model that actually produced the vector (e.g. text-embedding-3-large),
        // not the provider family's default, so the stored model matches the dimension.
        try saveConversationEmbedding(
            conversationId: conversationId,
            embedding: embedding,
            model: modelIdentifier(for: provider)
        )

        Log.api.info("Generated conversation embedding for \(conversationId) using \(provider.displayName) (\(embedding.count) dimensions)")
    }
}
|
|
|
|
// MARK: - Errors
|
|
|
|
/// Failures raised by `EmbeddingService` while selecting a provider, calling an
/// embedding API, or loading conversation data.
enum EmbeddingError: LocalizedError {
    /// The named provider has no API key configured.
    case missingAPIKey(String)
    /// The API replied with something that is not the expected JSON shape.
    case invalidResponse
    /// The API replied with a non-200 status; carries the status code and response body.
    case apiError(Int, String)
    /// The requested provider is not supported; carries a ready-to-show message.
    case providerNotImplemented(String)
    /// The conversation to embed could not be loaded.
    case conversationNotFound
    /// No provider has a usable API key.
    case noProvidersAvailable
}

extension EmbeddingError {
    /// User-facing text for each failure, surfaced through `LocalizedError`.
    var errorDescription: String? {
        let text: String
        switch self {
        case let .missingAPIKey(providerName):
            text = "\(providerName) API key not configured"
        case .invalidResponse:
            text = "Invalid response from embedding API"
        case let .apiError(statusCode, body):
            text = "Embedding API error (\(statusCode)): \(body)"
        case let .providerNotImplemented(detail):
            text = detail
        case .conversationNotFound:
            text = "Conversation not found"
        case .noProvidersAvailable:
            text = "No embedding providers available. Please configure an API key for OpenAI, OpenRouter, or Google."
        }
        return text
    }
}
|