242 lines
6.5 KiB
Swift
242 lines
6.5 KiB
Swift
//
// AIProvider.swift
// oAI
//
// Protocol for AI provider implementations
//
|
|
|
|
import Foundation
|
|
|
|
// MARK: - Provider Protocol
|
|
|
|
/// Interface every AI backend conforms to.
///
/// A provider exposes its name and capability flags, model discovery,
/// one-shot and streaming chat, credit lookup, and a lower-level chat entry
/// point that accepts pre-encoded messages for the MCP tool call loop.
protocol AIProvider {

    // MARK: Identity

    /// The provider's name.
    var name: String { get }

    /// Feature flags (streaming, vision, tools, online search, context length).
    var capabilities: ProviderCapabilities { get }

    // MARK: Models

    /// Lists the models this provider offers.
    func listModels() async throws -> [ModelInfo]

    /// Looks up one model by identifier.
    /// - Returns: The model, or `nil` (presumably when the id is unknown — confirm with implementations).
    func getModel(_ id: String) async throws -> ModelInfo?

    // MARK: Chat

    /// Runs a single chat completion and returns the full response.
    func chat(request: ChatRequest) async throws -> ChatResponse

    /// Runs a chat completion, delivering the output as a stream of chunks.
    func streamChat(request: ChatRequest) -> AsyncThrowingStream<StreamChunk, Error>

    // MARK: Account

    /// Fetches the account's credit information, or `nil` if none is available.
    func getCredits() async throws -> Credits?

    // MARK: Tool loop

    /// Chat completion with pre-encoded messages for the MCP tool call loop.
    /// - Parameters:
    ///   - model: Identifier of the model to use.
    ///   - messages: Messages already encoded as dictionaries.
    ///   - tools: Tools the model may call, if any.
    ///   - maxTokens: Optional cap on generated tokens.
    ///   - temperature: Optional sampling temperature.
    func chatWithToolMessages(
        model: String,
        messages: [[String: Any]],
        tools: [Tool]?,
        maxTokens: Int?,
        temperature: Double?
    ) async throws -> ChatResponse
}
|
|
|
|
// MARK: - Provider Capabilities
|
|
|
|
/// Feature flags advertised by a provider.
struct ProviderCapabilities: Codable {
    /// Whether responses can be streamed.
    let supportsStreaming: Bool
    /// Whether image inputs are supported.
    let supportsVision: Bool
    /// Whether tool calling is supported.
    let supportsTools: Bool
    /// Whether an online search mode is supported.
    let supportsOnlineSearch: Bool
    /// Maximum context length, or `nil` when not specified.
    let maxContextLength: Int?

    /// Baseline capabilities: streaming on; vision, tools, and online search
    /// off; no declared context length.
    static let `default` = Self(
        supportsStreaming: true,
        supportsVision: false,
        supportsTools: false,
        supportsOnlineSearch: false,
        maxContextLength: nil
    )
}
|
|
|
|
// MARK: - Chat Request
|
|
|
|
/// Parameters for one chat completion request.
struct ChatRequest {
    /// Conversation messages to send.
    let messages: [Message]
    /// Identifier of the model to use.
    let model: String
    /// Whether to stream the response (defaults to `true`).
    let stream: Bool
    /// Optional cap on generated tokens.
    let maxTokens: Int?
    /// Optional sampling temperature.
    let temperature: Double?
    /// Optional top-p sampling value.
    let topP: Double?
    /// Optional system prompt.
    let systemPrompt: String?
    /// Tools the model may call, if any.
    let tools: [Tool]?
    /// Whether online mode is requested (defaults to `false`).
    let onlineMode: Bool
    /// Whether image generation is requested (defaults to `false`).
    let imageGeneration: Bool

    /// Creates a request. Only `messages` and `model` are required; every
    /// other setting falls back to the default shown in the signature.
    init(
        messages: [Message],
        model: String,
        stream: Bool = true,
        maxTokens: Int? = nil,
        temperature: Double? = nil,
        topP: Double? = nil,
        systemPrompt: String? = nil,
        tools: [Tool]? = nil,
        onlineMode: Bool = false,
        imageGeneration: Bool = false
    ) {
        // Required inputs.
        self.model = model
        self.messages = messages

        // Generation settings.
        self.stream = stream
        self.maxTokens = maxTokens
        self.temperature = temperature
        self.topP = topP
        self.systemPrompt = systemPrompt

        // Optional features.
        self.tools = tools
        self.onlineMode = onlineMode
        self.imageGeneration = imageGeneration
    }
}
|
|
|
|
// MARK: - Chat Response
|
|
|
|
/// A tool call attached to a chat response.
///
/// Property order determines the memberwise initializer's argument order.
struct ToolCallInfo {
    /// Identifier of the tool call.
    let id: String
    /// Kind of tool call.
    let type: String
    /// Name of the function to invoke.
    let functionName: String
    /// Raw argument payload as a string (presumably JSON-encoded — confirm with producers).
    let arguments: String
}
|
|
|
|
/// A completed chat response returned by a provider.
struct ChatResponse: Codable {
    let id: String
    let model: String
    /// Generated message text.
    let content: String
    /// Role of the message author.
    let role: String
    let finishReason: String?
    let usage: Usage?
    let created: Date
    /// Tool calls attached to the response; excluded from Codable.
    let toolCalls: [ToolCallInfo]?
    /// Image data attached to the response; excluded from Codable.
    let generatedImages: [Data]?

    /// Token counts, coded with snake_case keys.
    struct Usage: Codable {
        let promptTokens: Int
        let completionTokens: Int
        let totalTokens: Int

        enum CodingKeys: String, CodingKey {
            case promptTokens = "prompt_tokens"
            case completionTokens = "completion_tokens"
            case totalTokens = "total_tokens"
        }
    }

    // Only these keys are encoded and decoded. `toolCalls` and `generatedImages`
    // are intentionally omitted: encoding skips them and decoding sets them to nil.
    enum CodingKeys: String, CodingKey {
        case id, model, content, role, finishReason, usage, created
    }

    /// Creates a response directly; `toolCalls` and `generatedImages` default to `nil`.
    init(
        id: String,
        model: String,
        content: String,
        role: String,
        finishReason: String?,
        usage: Usage?,
        created: Date,
        toolCalls: [ToolCallInfo]? = nil,
        generatedImages: [Data]? = nil
    ) {
        self.id = id
        self.model = model
        self.content = content
        self.role = role
        self.finishReason = finishReason
        self.usage = usage
        self.created = created
        self.toolCalls = toolCalls
        self.generatedImages = generatedImages
    }

    /// Decodes the `CodingKeys` subset; the non-serialized fields become `nil`.
    /// - Throws: A `DecodingError` if a required key is missing or mistyped.
    init(from decoder: Decoder) throws {
        let values = try decoder.container(keyedBy: CodingKeys.self)

        // Required string fields.
        self.id = try values.decode(String.self, forKey: .id)
        self.model = try values.decode(String.self, forKey: .model)
        self.content = try values.decode(String.self, forKey: .content)
        self.role = try values.decode(String.self, forKey: .role)

        // Optional fields tolerate a missing key.
        self.finishReason = try values.decodeIfPresent(String.self, forKey: .finishReason)
        self.usage = try values.decodeIfPresent(Usage.self, forKey: .usage)

        self.created = try values.decode(Date.self, forKey: .created)

        // Not part of the serialized form.
        self.toolCalls = nil
        self.generatedImages = nil
    }
}
|
|
|
|
// MARK: - Stream Chunk
|
|
|
|
/// One incremental piece of a streamed chat response.
struct StreamChunk {
    /// Identifier reported with the chunk.
    let id: String
    /// Model reported with the chunk.
    let model: String
    /// The incremental payload.
    let delta: Delta
    /// Finish reason, when this chunk carries one.
    let finishReason: String?
    /// Token usage, when this chunk carries it.
    let usage: ChatResponse.Usage?

    /// Incremental payload of a chunk; every field may be absent.
    struct Delta {
        /// Text fragment, if any.
        let content: String?
        /// Author role, if any.
        let role: String?
        /// Image data, if any.
        let images: [Data]?
    }
}

extension StreamChunk {
    /// Convenience accessor equivalent to `delta.content`.
    var deltaContent: String? {
        return delta.content
    }
}
|
|
|
|
// MARK: - Tool Definition
|
|
|
|
/// A tool definition that can be offered to a model.
struct Tool: Codable {
    /// Tool kind string.
    let type: String
    /// The function's name, description, and parameter schema.
    let function: Function
}

extension Tool {
    /// A callable function exposed to the model.
    struct Function: Codable {
        let name: String
        let description: String
        let parameters: Parameters
    }
}

extension Tool.Function {
    /// JSON-Schema-style description of the function's arguments.
    struct Parameters: Codable {
        let type: String
        /// Argument schemas keyed by argument name.
        let properties: [String: Property]
        /// Names of required arguments; omitted from JSON when `nil`.
        let required: [String]?
    }
}

extension Tool.Function.Parameters {
    /// Schema for a single argument.
    struct Property: Codable {
        let type: String
        let description: String
        /// Allowed values, if restricted; coded under the key "enum".
        let `enum`: [String]?
    }
}
|
|
|
|
// MARK: - Credits
|
|
|
|
/// Account credit information reported by a provider.
struct Credits: Codable {
    /// Current balance, expressed in `currency`.
    let balance: Double
    /// Currency code for all amounts (e.g. "USD").
    let currency: String
    /// Amount used, if reported.
    let usage: Double?
    /// Limit, if reported.
    let limit: Double?

    /// `balance` with two decimals, labelled with `currency`.
    var balanceDisplay: String {
        formattedAmount(balance)
    }

    /// `usage` formatted like `balanceDisplay`, or `nil` when usage isn't reported.
    var usageDisplay: String? {
        guard let usage = usage else { return nil }
        return formattedAmount(usage)
    }

    /// Formats `amount` with two decimals in this record's currency.
    ///
    /// USD (any letter case) and an empty/blank currency render as "$12.34",
    /// matching the historical output. Any other code is appended as a suffix
    /// ("12.34 EUR") so non-USD balances are not mislabelled as dollars.
    private func formattedAmount(_ amount: Double) -> String {
        let number = String(format: "%.2f", amount)
        let code = currency.trimmingCharacters(in: .whitespacesAndNewlines)
        if code.isEmpty || code.caseInsensitiveCompare("USD") == .orderedSame {
            return "$" + number
        }
        return "\(number) \(code)"
    }
}
|
|
|
|
// MARK: - Provider Errors
|
|
|
|
/// Errors raised by provider implementations, with user-facing descriptions.
enum ProviderError: LocalizedError {
    case invalidAPIKey
    /// Wraps an underlying error; its `localizedDescription` is shown.
    case networkError(Error)
    case invalidResponse
    case rateLimitExceeded
    /// Carries the requested model identifier.
    case modelNotFound(String)
    case insufficientCredits
    case timeout
    /// Carries a free-form message.
    case unknown(String)

    /// Human-readable message for each case; never `nil`.
    var errorDescription: String? {
        let text: String
        switch self {
        case .invalidAPIKey:
            text = "Invalid API key. Please check your settings."
        case let .networkError(underlying):
            text = "Network error: \(underlying.localizedDescription)"
        case .invalidResponse:
            text = "Received invalid response from API"
        case .rateLimitExceeded:
            text = "Rate limit exceeded. Please try again later."
        case let .modelNotFound(modelID):
            text = "Model '\(modelID)' not found"
        case .insufficientCredits:
            text = "Insufficient credits"
        case .timeout:
            text = "Request timed out"
        case let .unknown(detail):
            text = "Unknown error: \(detail)"
        }
        return text
    }
}
|