//
// AIProvider.swift
// oAI
//
// Protocol for AI provider implementations
//
// SPDX-License-Identifier: AGPL-3.0-or-later
// Copyright (C) 2026 Rune Olsen
//
// This file is part of oAI.
//
// oAI is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// oAI is distributed in the hope that it will be useful, but WITHOUT
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
// or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General
// Public License for more details.
//
// You should have received a copy of the GNU Affero General Public
// License along with oAI. If not, see <https://www.gnu.org/licenses/>.

import Foundation

// MARK: - Provider Protocol
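/// A chat-completion backend: model discovery, blocking and streaming chat,
/// optional credit reporting, and a raw-message entry point for the MCP tool
/// call loop.
///
/// Illustrative usage sketch; `provider`, `history`, and `"model-id"` are
/// placeholders, not part of this API:
///
/// ```swift
/// let request = ChatRequest(messages: history, model: "model-id")
/// for try await chunk in provider.streamChat(request: request) {
///     if let text = chunk.deltaContent { print(text, terminator: "") }
/// }
/// ```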
protocol AIProvider {
    /// Human-readable provider name shown in the UI.
    var name: String { get }
    var capabilities: ProviderCapabilities { get }

    func listModels() async throws -> [ModelInfo]
    func getModel(_ id: String) async throws -> ModelInfo?
    func chat(request: ChatRequest) async throws -> ChatResponse
    func streamChat(request: ChatRequest) -> AsyncThrowingStream<StreamChunk, Error>
    func getCredits() async throws -> Credits?

    /// Chat completion with pre-encoded messages for the MCP tool call loop.
    func chatWithToolMessages(model: String, messages: [[String: Any]], tools: [Tool]?, maxTokens: Int?, temperature: Double?) async throws -> ChatResponse
}

// MARK: - Provider Capabilities
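/// Feature flags a provider advertises; `default` is the most conservative set
/// (streaming only, no vision/tools/online search, unknown context length).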
struct ProviderCapabilities: Codable {
    let supportsStreaming: Bool
    let supportsVision: Bool
    let supportsTools: Bool
    let supportsOnlineSearch: Bool
    let maxContextLength: Int?

    static let `default` = ProviderCapabilities(
        supportsStreaming: true,
        supportsVision: false,
        supportsTools: false,
        supportsOnlineSearch: false,
        maxContextLength: nil
    )
}

// MARK: - Chat Request
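/// A single chat-completion request. Only `messages` and `model` are required;
/// the remaining fields fall back to the initializer's defaults.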
struct ChatRequest {
    let messages: [Message]
    let model: String
    let stream: Bool
    let maxTokens: Int?
    let temperature: Double?
    let topP: Double?
    let systemPrompt: String?
    let tools: [Tool]?
    let onlineMode: Bool
    let imageGeneration: Bool

    init(
        messages: [Message],
        model: String,
        stream: Bool = true,
        maxTokens: Int? = nil,
        temperature: Double? = nil,
        topP: Double? = nil,
        systemPrompt: String? = nil,
        tools: [Tool]? = nil,
        onlineMode: Bool = false,
        imageGeneration: Bool = false
    ) {
        self.messages = messages
        self.model = model
        self.stream = stream
        self.maxTokens = maxTokens
        self.temperature = temperature
        self.topP = topP
        self.systemPrompt = systemPrompt
        self.tools = tools
        self.onlineMode = onlineMode
        self.imageGeneration = imageGeneration
    }
}

// MARK: - Chat Response
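/// A tool (function) call requested by the model, with its arguments kept as a raw string.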
struct ToolCallInfo {
    let id: String
    let type: String
    let functionName: String
    let arguments: String
}

/// A completed (non-streaming) chat response, plus locally populated tool calls
/// and generated images.
struct ChatResponse: Codable {
    let id: String
    let model: String
    let content: String
    let role: String
    let finishReason: String?
    let usage: Usage?
    let created: Date
    let toolCalls: [ToolCallInfo]?
    let generatedImages: [Data]?

    struct Usage: Codable {
        let promptTokens: Int
        let completionTokens: Int
        let totalTokens: Int

        enum CodingKeys: String, CodingKey {
            case promptTokens = "prompt_tokens"
            case completionTokens = "completion_tokens"
            case totalTokens = "total_tokens"
        }
    }

    // Custom Codable: toolCalls and generatedImages are populated locally, not
    // decoded from the API payload.
    enum CodingKeys: String, CodingKey {
        case id, model, content, role, finishReason, usage, created
    }

    init(id: String, model: String, content: String, role: String, finishReason: String?, usage: Usage?, created: Date, toolCalls: [ToolCallInfo]? = nil, generatedImages: [Data]? = nil) {
        self.id = id
        self.model = model
        self.content = content
        self.role = role
        self.finishReason = finishReason
        self.usage = usage
        self.created = created
        self.toolCalls = toolCalls
        self.generatedImages = generatedImages
    }

    init(from decoder: Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)
        id = try container.decode(String.self, forKey: .id)
        model = try container.decode(String.self, forKey: .model)
        content = try container.decode(String.self, forKey: .content)
        role = try container.decode(String.self, forKey: .role)
        finishReason = try container.decodeIfPresent(String.self, forKey: .finishReason)
        usage = try container.decodeIfPresent(Usage.self, forKey: .usage)
        created = try container.decode(Date.self, forKey: .created)
        toolCalls = nil
        generatedImages = nil
    }
}

// MARK: - Stream Chunk
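/// One incremental delta from a streaming chat response.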
struct StreamChunk {
    let id: String
    let model: String
    let delta: Delta
    let finishReason: String?
    let usage: ChatResponse.Usage?

    struct Delta {
        let content: String?
        let role: String?
        let images: [Data]?
    }

    var deltaContent: String? {
        delta.content
    }
}

// MARK: - Tool Definition
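/// A function tool definition with a JSON-Schema-style `parameters` object
/// (OpenAI-compatible function-calling shape).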
struct Tool: Codable {
    let type: String
    let function: Function

    struct Function: Codable {
        let name: String
        let description: String
        let parameters: Parameters

        struct Parameters: Codable {
            let type: String
            let properties: [String: Property]
            let required: [String]?

            struct Property: Codable {
                let type: String
                let description: String
                let `enum`: [String]?
            }
        }
    }
}

// MARK: - Credits
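/// Account balance and usage as reported by the provider; the display helpers
/// format amounts as dollars.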
struct Credits: Codable {
    let balance: Double
    let currency: String
    let usage: Double?
    let limit: Double?

    var balanceDisplay: String {
        String(format: "$%.2f", balance)
    }

    var usageDisplay: String? {
        guard let usage = usage else { return nil }
        return String(format: "$%.2f", usage)
    }
}

// MARK: - Provider Errors
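/// Errors surfaced by providers; `errorDescription` supplies the user-facing message.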
enum ProviderError: LocalizedError {
    case invalidAPIKey
    case networkError(Error)
    case invalidResponse
    case rateLimitExceeded
    case modelNotFound(String)
    case insufficientCredits
    case timeout
    case unknown(String)

    var errorDescription: String? {
        switch self {
        case .invalidAPIKey:
            return "Invalid API key. Please check your settings."
        case .networkError(let error):
            return "Network error: \(error.localizedDescription)"
        case .invalidResponse:
            return "Received an invalid response from the API."
        case .rateLimitExceeded:
            return "Rate limit exceeded. Please try again later."
        case .modelNotFound(let model):
            return "Model '\(model)' not found."
        case .insufficientCredits:
            return "Insufficient credits."
        case .timeout:
            return "Request timed out."
        case .unknown(let message):
            return "Unknown error: \(message)"
        }
    }
}