Files
oai-swift/oAI/Services/EmbeddingService.swift

427 lines
16 KiB
Swift

//
// EmbeddingService.swift
// oAI
//
// Embedding generation and semantic search
// Supports multiple providers: OpenAI, OpenRouter, Google
//
// SPDX-License-Identifier: AGPL-3.0-or-later
// Copyright (C) 2026 Rune Olsen
//
// This file is part of oAI.
//
// oAI is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// oAI is distributed in the hope that it will be useful, but WITHOUT
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
// or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General
// Public License for more details.
//
// You should have received a copy of the GNU Affero General Public
// License along with oAI. If not, see <https://www.gnu.org/licenses/>.
import Foundation
import os
// MARK: - Embedding Provider
/// An embedding backend together with the concrete model to request from it.
enum EmbeddingProvider {
    case openai(model: String)
    case openrouter(model: String)
    case google(model: String)

    /// The model identifier used by default for this provider.
    ///
    /// Note: this ignores the associated `model` value; it is the provider-wide
    /// default, not necessarily the model carried by this instance.
    var defaultModel: String {
        switch self {
        case .openai: return "text-embedding-3-small"
        case .openrouter: return "openai/text-embedding-3-small"
        case .google: return "text-embedding-004"
        }
    }

    /// Expected length of the embedding vector for the associated model.
    var dimension: Int {
        switch self {
        case .openai(let model):
            return model == "text-embedding-3-large" ? 3072 : 1536
        case .openrouter(let model):
            if model.contains("text-embedding-3-large") {
                return 3072
            } else if model.contains("qwen3-embedding-8b") {
                // Qwen3-Embedding-8B emits 4096-dimensional vectors
                // (the previous value of 8192 did not match the model).
                return 4096
            } else {
                return 1536 // Default for most models
            }
        case .google:
            return 768
        }
    }

    /// Human-readable provider name for UI and log output.
    var displayName: String {
        switch self {
        case .openai: return "OpenAI"
        case .openrouter: return "OpenRouter"
        case .google: return "Google"
        }
    }
}
// MARK: - Embedding Service
/// Generates text embeddings through the configured provider, compares them,
/// and persists them via `DatabaseService`.
final class EmbeddingService {
    static let shared = EmbeddingService()

    private let settings = SettingsService.shared

    private init() {}

    // MARK: - Provider Detection

    /// `true` when a non-empty OpenAI API key is configured.
    private var hasOpenAIKey: Bool { settings.openaiAPIKey?.isEmpty == false }

    /// `true` when a non-empty OpenRouter API key is configured.
    private var hasOpenRouterKey: Bool { settings.openrouterAPIKey?.isEmpty == false }

    /// `true` when a non-empty Google API key is configured.
    private var hasGoogleKey: Bool { settings.googleAPIKey?.isEmpty == false }

    /// Get the embedding provider based on user's selection in settings.
    ///
    /// - Returns: The provider for the selected option, or `nil` when the
    ///   selected option's API key is missing or empty. An unrecognized
    ///   selection falls back to `getBestAvailableProvider()`.
    func getSelectedProvider() -> EmbeddingProvider? {
        switch settings.embeddingProvider {
        case "openai-small":
            guard hasOpenAIKey else { return nil }
            return .openai(model: "text-embedding-3-small")
        case "openai-large":
            guard hasOpenAIKey else { return nil }
            return .openai(model: "text-embedding-3-large")
        case "openrouter-openai-small":
            guard hasOpenRouterKey else { return nil }
            return .openrouter(model: "openai/text-embedding-3-small")
        case "openrouter-openai-large":
            guard hasOpenRouterKey else { return nil }
            return .openrouter(model: "openai/text-embedding-3-large")
        case "openrouter-qwen":
            guard hasOpenRouterKey else { return nil }
            return .openrouter(model: "qwen/qwen3-embedding-8b")
        case "google-gemini":
            guard hasGoogleKey else { return nil }
            return .google(model: "text-embedding-004")
        default:
            // Fall back to best available
            return getBestAvailableProvider()
        }
    }

    /// Get the best available embedding provider based on user's API keys
    /// (priority: OpenAI → OpenRouter → Google).
    ///
    /// - Returns: The first provider with a non-empty key, or `nil` if none.
    func getBestAvailableProvider() -> EmbeddingProvider? {
        if hasOpenAIKey {
            return .openai(model: "text-embedding-3-small")
        }
        if hasOpenRouterKey {
            return .openrouter(model: "openai/text-embedding-3-small")
        }
        if hasGoogleKey {
            return .google(model: "text-embedding-004")
        }
        return nil
    }

    /// Check if embeddings are available (user has at least one compatible provider).
    var isAvailable: Bool {
        return getBestAvailableProvider() != nil
    }

    // MARK: - Embedding Generation

    /// Generate an embedding for `text` using the given provider.
    ///
    /// - Parameters:
    ///   - text: The text to embed.
    ///   - provider: Provider and model to call.
    /// - Returns: The embedding vector.
    /// - Throws: `EmbeddingError` for a missing key, a non-200 status or an
    ///   unexpected payload; networking and JSON errors are rethrown.
    func generateEmbedding(text: String, provider: EmbeddingProvider) async throws -> [Float] {
        switch provider {
        case .openai(let model):
            return try await generateOpenAIEmbedding(text: text, model: model)
        case .openrouter(let model):
            return try await generateOpenRouterEmbedding(text: text, model: model)
        case .google(let model):
            return try await generateGoogleEmbedding(text: text, model: model)
        }
    }

    /// Generate OpenAI embedding.
    private func generateOpenAIEmbedding(text: String, model: String) async throws -> [Float] {
        guard let apiKey = settings.openaiAPIKey, !apiKey.isEmpty else {
            throw EmbeddingError.missingAPIKey("OpenAI")
        }
        // Constant, known-valid URL literal.
        let url = URL(string: "https://api.openai.com/v1/embeddings")!
        var request = try makeJSONPostRequest(url: url, body: ["input": text, "model": model])
        request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")

        let (statusCode, data) = try await send(request)
        guard statusCode == 200 else {
            let errorMessage = String(data: data, encoding: .utf8) ?? "Unknown error"
            Log.api.error("OpenAI embedding error (\(statusCode)): \(errorMessage)")
            throw EmbeddingError.apiError(statusCode, errorMessage)
        }
        return try parseOpenAIStyleEmbedding(from: data)
    }

    /// Generate OpenRouter embedding (OpenAI-compatible API).
    private func generateOpenRouterEmbedding(text: String, model: String) async throws -> [Float] {
        guard let apiKey = settings.openrouterAPIKey, !apiKey.isEmpty else {
            throw EmbeddingError.missingAPIKey("OpenRouter")
        }
        // Constant, known-valid URL literal.
        let url = URL(string: "https://openrouter.ai/api/v1/embeddings")!
        var request = try makeJSONPostRequest(url: url, body: ["input": text, "model": model])
        request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
        request.setValue("https://github.com/yourusername/oAI", forHTTPHeaderField: "HTTP-Referer")

        let (statusCode, data) = try await send(request)
        guard statusCode == 200 else {
            let errorMessage = String(data: data, encoding: .utf8) ?? "Unknown error"
            Log.api.error("OpenRouter embedding error (\(statusCode)): \(errorMessage)")
            throw EmbeddingError.apiError(statusCode, errorMessage)
        }
        return try parseOpenAIStyleEmbedding(from: data)
    }

    /// Generate Google embedding.
    ///
    /// The API key is sent in the `x-goog-api-key` header rather than the URL
    /// query string, so it cannot end up in URL logs, and the model name is
    /// percent-encoded so an unusual name cannot crash URL construction.
    private func generateGoogleEmbedding(text: String, model: String) async throws -> [Float] {
        guard let apiKey = settings.googleAPIKey, !apiKey.isEmpty else {
            throw EmbeddingError.missingAPIKey("Google")
        }
        guard let encodedModel = model.addingPercentEncoding(withAllowedCharacters: .urlPathAllowed),
              let url = URL(string: "https://generativelanguage.googleapis.com/v1beta/models/\(encodedModel):embedContent") else {
            throw EmbeddingError.providerNotImplemented("Unsupported Google embedding model: \(model)")
        }
        let body: [String: Any] = [
            "content": [
                "parts": [
                    ["text": text]
                ]
            ]
        ]
        var request = try makeJSONPostRequest(url: url, body: body)
        request.setValue(apiKey, forHTTPHeaderField: "x-goog-api-key")

        let (statusCode, data) = try await send(request)
        guard statusCode == 200 else {
            let errorMessage = String(data: data, encoding: .utf8) ?? "Unknown error"
            Log.api.error("Google embedding error (\(statusCode)): \(errorMessage)")
            throw EmbeddingError.apiError(statusCode, errorMessage)
        }
        let json = try JSONSerialization.jsonObject(with: data) as? [String: Any]
        guard let embedding = json?["embedding"] as? [String: Any],
              let values = embedding["values"] as? [Double] else {
            throw EmbeddingError.invalidResponse
        }
        return values.map { Float($0) }
    }

    // MARK: - Request Helpers

    /// Build a POST request carrying `body` as JSON.
    private func makeJSONPostRequest(url: URL, body: [String: Any]) throws -> URLRequest {
        var request = URLRequest(url: url)
        request.httpMethod = "POST"
        request.setValue("application/json", forHTTPHeaderField: "Content-Type")
        request.httpBody = try JSONSerialization.data(withJSONObject: body)
        return request
    }

    /// Send `request` and return the HTTP status code with the raw response body.
    ///
    /// - Throws: `EmbeddingError.invalidResponse` when the response is not HTTP.
    private func send(_ request: URLRequest) async throws -> (statusCode: Int, body: Data) {
        let (data, response) = try await URLSession.shared.data(for: request)
        guard let httpResponse = response as? HTTPURLResponse else {
            throw EmbeddingError.invalidResponse
        }
        return (httpResponse.statusCode, data)
    }

    /// Extract `data[0].embedding` from an OpenAI-style embeddings response.
    private func parseOpenAIStyleEmbedding(from data: Data) throws -> [Float] {
        let json = try JSONSerialization.jsonObject(with: data) as? [String: Any]
        guard let items = json?["data"] as? [[String: Any]],
              let first = items.first,
              let embedding = first["embedding"] as? [Double] else {
            throw EmbeddingError.invalidResponse
        }
        return embedding.map { Float($0) }
    }

    /// The model identifier actually carried by `provider`.
    private func modelIdentifier(for provider: EmbeddingProvider) -> String {
        switch provider {
        case .openai(let model), .openrouter(let model), .google(let model):
            return model
        }
    }

    // MARK: - Similarity Calculation

    /// Calculate cosine similarity between two embeddings.
    ///
    /// - Returns: The cosine of the angle between `a` and `b`, or `0.0` when
    ///   the lengths differ (logged as an error) or either vector has zero
    ///   magnitude.
    func cosineSimilarity(_ a: [Float], _ b: [Float]) -> Float {
        guard a.count == b.count else {
            Log.api.error("Embedding dimension mismatch: \(a.count) vs \(b.count)")
            return 0.0
        }
        var dotProduct: Float = 0.0
        var squaredA: Float = 0.0
        var squaredB: Float = 0.0
        for (x, y) in zip(a, b) {
            dotProduct += x * y
            squaredA += x * x
            squaredB += y * y
        }
        let magnitudeA = sqrt(squaredA)
        let magnitudeB = sqrt(squaredB)
        guard magnitudeA > 0 && magnitudeB > 0 else {
            return 0.0
        }
        return dotProduct / (magnitudeA * magnitudeB)
    }

    // MARK: - Database Operations

    /// Save message embedding to database.
    func saveMessageEmbedding(messageId: UUID, embedding: [Float], model: String) throws {
        let data = serializeEmbedding(embedding)
        try DatabaseService.shared.saveMessageEmbedding(
            messageId: messageId,
            embedding: data,
            model: model,
            dimension: embedding.count
        )
    }

    /// Get message embedding from database, or `nil` when none is stored.
    func getMessageEmbedding(messageId: UUID) throws -> [Float]? {
        guard let data = try DatabaseService.shared.getMessageEmbedding(messageId: messageId) else {
            return nil
        }
        return deserializeEmbedding(data)
    }

    /// Save conversation embedding to database.
    func saveConversationEmbedding(conversationId: UUID, embedding: [Float], model: String) throws {
        let data = serializeEmbedding(embedding)
        try DatabaseService.shared.saveConversationEmbedding(
            conversationId: conversationId,
            embedding: data,
            model: model,
            dimension: embedding.count
        )
    }

    /// Get conversation embedding from database, or `nil` when none is stored.
    func getConversationEmbedding(conversationId: UUID) throws -> [Float]? {
        guard let data = try DatabaseService.shared.getConversationEmbedding(conversationId: conversationId) else {
            return nil
        }
        return deserializeEmbedding(data)
    }

    // MARK: - Serialization

    /// Serialize embedding to binary data (4 bytes per float, little-endian).
    private func serializeEmbedding(_ embedding: [Float]) -> Data {
        var data = Data(capacity: embedding.count * 4)
        for value in embedding {
            withUnsafeBytes(of: value.bitPattern.littleEndian) { bytes in
                data.append(contentsOf: bytes)
            }
        }
        return data
    }

    /// Deserialize embedding from binary data (4 bytes per float, little-endian).
    ///
    /// Only complete 4-byte groups are decoded; trailing bytes of a truncated
    /// or corrupt blob are ignored instead of trapping on an out-of-range read.
    private func deserializeEmbedding(_ data: Data) -> [Float] {
        // Copying into an array makes indexing zero-based even for Data slices
        // and avoids any alignment requirement on the underlying storage.
        let bytes = [UInt8](data)
        let floatCount = bytes.count / 4
        var embedding: [Float] = []
        embedding.reserveCapacity(floatCount)
        for index in 0..<floatCount {
            let base = index * 4
            let b0 = UInt32(bytes[base])
            let b1 = UInt32(bytes[base + 1]) << 8
            let b2 = UInt32(bytes[base + 2]) << 16
            let b3 = UInt32(bytes[base + 3]) << 24
            let bitPattern: UInt32 = b0 | b1 | b2 | b3
            embedding.append(Float(bitPattern: bitPattern))
        }
        return embedding
    }

    // MARK: - Conversation Embedding Generation

    /// Generate and store an embedding for an entire conversation
    /// (aggregate of its user and assistant messages).
    ///
    /// - Throws: `EmbeddingError.noProvidersAvailable` when no provider can be
    ///   selected, `EmbeddingError.conversationNotFound` when the conversation
    ///   cannot be loaded, plus any error from generation or saving.
    func generateConversationEmbedding(conversationId: UUID) async throws {
        // Use user's selected provider, or fall back to best available
        guard let provider = getSelectedProvider() else {
            throw EmbeddingError.noProvidersAvailable
        }
        // Load conversation messages
        guard let (_, messages) = try? DatabaseService.shared.loadConversation(id: conversationId) else {
            throw EmbeddingError.conversationNotFound
        }
        // Combine all message content
        let chatMessages = messages.filter { $0.role == .user || $0.role == .assistant }
        let combinedText = chatMessages.map { $0.content }.joined(separator: "\n\n")
        // Truncate if too long (8191 tokens max for most embedding models)
        let truncated = String(combinedText.prefix(30000)) // ~7500 tokens
        // Generate embedding
        let embedding = try await generateEmbedding(text: truncated, provider: provider)
        // Record the model that actually produced the vector, not the
        // provider-wide default, so large/Qwen embeddings are labeled correctly.
        try saveConversationEmbedding(
            conversationId: conversationId,
            embedding: embedding,
            model: modelIdentifier(for: provider)
        )
        Log.api.info("Generated conversation embedding for \(conversationId) using \(provider.displayName) (\(embedding.count) dimensions)")
    }
}
// MARK: - Errors
/// Failures that can occur while generating or loading embeddings.
enum EmbeddingError: LocalizedError {
    /// No API key is configured for the named provider.
    case missingAPIKey(String)
    /// The response was not HTTP or did not have the expected JSON shape.
    case invalidResponse
    /// The API answered with a non-success status code and message body.
    case apiError(Int, String)
    /// A provider or feature is unavailable; the payload is the full message.
    case providerNotImplemented(String)
    /// The requested conversation could not be loaded.
    case conversationNotFound
    /// No provider has a usable API key.
    case noProvidersAvailable

    /// User-facing description of the error.
    var errorDescription: String? {
        let message: String
        switch self {
        case let .missingAPIKey(providerName):
            message = "\(providerName) API key not configured"
        case .invalidResponse:
            message = "Invalid response from embedding API"
        case let .apiError(statusCode, detail):
            message = "Embedding API error (\(statusCode)): \(detail)"
        case let .providerNotImplemented(detail):
            message = detail
        case .conversationNotFound:
            message = "Conversation not found"
        case .noProvidersAvailable:
            message = "No embedding providers available. Please configure an API key for OpenAI, OpenRouter, or Google."
        }
        return message
    }
}