Initial commit
This commit is contained in:
308
oAI/Providers/OllamaProvider.swift
Normal file
308
oAI/Providers/OllamaProvider.swift
Normal file
@@ -0,0 +1,308 @@
//
//  OllamaProvider.swift
//  oAI
//
//  Ollama local AI provider with JSON-lines streaming
//
|
||||
import Foundation
|
||||
import os
|
||||
|
||||
/// Ollama local AI provider.
///
/// Talks to a locally running Ollama server over its HTTP API
/// (`/api/tags`, `/api/chat`), using Ollama's JSON-lines streaming format
/// (one complete JSON object per line) for incremental responses.
class OllamaProvider: AIProvider {
    let name = "Ollama"
    let capabilities = ProviderCapabilities(
        supportsStreaming: true,
        supportsVision: false,
        supportsTools: false,
        supportsOnlineSearch: false,
        maxContextLength: nil  // varies per installed model; not known up front
    )

    private let baseURL: String
    private let session: URLSession

    /// - Parameter baseURL: Root URL of the local Ollama server. A trailing
    ///   slash is stripped so later path concatenation stays consistent.
    init(baseURL: String = "http://localhost:11434") {
        self.baseURL = baseURL.hasSuffix("/") ? String(baseURL.dropLast()) : baseURL
        let config = URLSessionConfiguration.default
        // Local models can be slow to produce the first token, and a full
        // generation can take minutes — use generous timeouts.
        config.timeoutIntervalForRequest = 120
        config.timeoutIntervalForResource = 600
        self.session = URLSession(configuration: config)
    }

    // MARK: - Models

    /// Fetches the locally installed models from `/api/tags`.
    /// - Returns: One `ModelInfo` per installed model; empty if the payload
    ///   has no `models` array.
    /// - Throws: `ProviderError.unknown` when the server is unreachable or
    ///   returns a non-200 status.
    func listModels() async throws -> [ModelInfo] {
        Log.api.info("Fetching model list from Ollama at \(self.baseURL)")
        var request = URLRequest(url: try endpoint("/api/tags"))
        request.timeoutInterval = 5  // fail fast when Ollama isn't running

        let data: Data
        let response: URLResponse
        do {
            (data, response) = try await session.data(for: request)
        } catch {
            Log.api.warning("Cannot connect to Ollama at \(self.baseURL). Is Ollama running?")
            throw ProviderError.unknown("Cannot connect to Ollama at \(baseURL). Is Ollama running? Start it with: ollama serve")
        }

        guard let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode == 200 else {
            throw ProviderError.unknown("Ollama returned an error. Is it running?")
        }

        guard let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
              let modelsArray = json["models"] as? [[String: Any]] else {
            return []
        }

        return modelsArray.compactMap { model -> ModelInfo? in
            guard let name = model["name"] as? String else { return nil }
            let sizeBytes = model["size"] as? Int64 ?? 0
            let sizeGB = String(format: "%.1f GB", Double(sizeBytes) / 1_073_741_824)

            return ModelInfo(
                id: name,
                name: name,
                description: "Local model (\(sizeGB))",
                contextLength: 0,  // /api/tags does not report context length
                pricing: .init(prompt: 0, completion: 0),  // local models are free
                capabilities: .init(vision: false, tools: false, online: false)
            )
        }
    }

    /// Looks up a single model by id (refetches the full list each call).
    func getModel(_ id: String) async throws -> ModelInfo? {
        try await listModels().first { $0.id == id }
    }

    // MARK: - Chat Completion

    /// Sends a non-streaming chat request to `/api/chat`.
    /// - Throws: `ProviderError.unknown` with Ollama's error message on a
    ///   non-200 status; `ProviderError.invalidResponse` on unparseable JSON.
    func chat(request: ChatRequest) async throws -> ChatResponse {
        Log.api.info("Ollama chat request: model=\(request.model), messages=\(request.messages.count)")
        let urlRequest = try makePostRequest(
            path: "/api/chat",
            body: buildRequestBody(from: request, stream: false)
        )

        let (data, response) = try await session.data(for: urlRequest)
        try validate(response, data: data)

        guard let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] else {
            throw ProviderError.invalidResponse
        }
        return parseOllamaResponse(json, model: request.model)
    }

    // MARK: - Chat with raw tool messages

    /// Sends pre-built messages (possibly containing tool roles) as plain chat.
    ///
    /// Ollama doesn't support tool calls natively here, so tool results are
    /// folded into user messages and assistant `tool_calls` are summarized as
    /// text. The `tools` parameter is accepted for interface compatibility but
    /// ignored.
    func chatWithToolMessages(model: String, messages: [[String: Any]], tools: [Tool]?, maxTokens: Int?, temperature: Double?) async throws -> ChatResponse {
        // Convert messages, stripping tool-specific fields.
        var ollamaMessages: [[String: Any]] = []
        for msg in messages {
            let role = msg["role"] as? String ?? "user"
            let content = msg["content"] as? String ?? ""

            if role == "tool" {
                // Fold the tool result into a user message so the model sees
                // it as ordinary context.
                let toolName = msg["name"] as? String ?? "tool"
                ollamaMessages.append(["role": "user", "content": "[\(toolName) result]: \(content)"])
            } else if role == "assistant" {
                // Strip tool_calls; keep (or synthesize) plain text content.
                if let tc = msg["tool_calls"] as? [[String: Any]], !tc.isEmpty {
                    let toolNames = tc.compactMap { ($0["function"] as? [String: Any])?["name"] as? String }
                    let text = (msg["content"] as? String) ?? ""
                    let combined = text.isEmpty ? "Calling: \(toolNames.joined(separator: ", "))" : text
                    ollamaMessages.append(["role": "assistant", "content": combined])
                } else {
                    ollamaMessages.append(["role": "assistant", "content": content])
                }
            } else {
                ollamaMessages.append(["role": role, "content": content])
            }
        }

        var body: [String: Any] = [
            "model": model,
            "messages": ollamaMessages,
            "stream": false
        ]
        // Subscript assignment of nil leaves "options" absent when empty.
        body["options"] = makeOptions(maxTokens: maxTokens, temperature: temperature)

        let urlRequest = try makePostRequest(path: "/api/chat", body: body)
        let (data, response) = try await session.data(for: urlRequest)
        try validate(response, data: data)

        guard let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] else {
            throw ProviderError.invalidResponse
        }
        return parseOllamaResponse(json, model: model)
    }

    // MARK: - Streaming Chat

    /// Streams a chat completion from `/api/chat`.
    ///
    /// Ollama emits one complete JSON object per line; the final object has
    /// `done == true` and carries the token-usage counters.
    func streamChat(request: ChatRequest) -> AsyncThrowingStream<StreamChunk, Error> {
        Log.api.info("Ollama stream request: model=\(request.model), messages=\(request.messages.count)")
        return AsyncThrowingStream { continuation in
            Task {
                do {
                    let urlRequest = try makePostRequest(
                        path: "/api/chat",
                        body: buildRequestBody(from: request, stream: true)
                    )

                    let (bytes, response) = try await session.bytes(for: urlRequest)
                    // No error body is available from a byte stream before
                    // consuming it, so validate the status alone.
                    try validate(response, data: nil)

                    // One complete JSON object per line.
                    for try await line in bytes.lines {
                        guard !line.isEmpty,
                              let lineData = line.data(using: .utf8),
                              let json = try? JSONSerialization.jsonObject(with: lineData) as? [String: Any] else {
                            continue  // skip blank/unparseable lines
                        }

                        let done = json["done"] as? Bool ?? false
                        let content = (json["message"] as? [String: Any])?["content"] as? String

                        if done {
                            // Final chunk carries the usage statistics.
                            continuation.yield(StreamChunk(
                                id: "",
                                model: request.model,
                                delta: .init(content: content, role: nil, images: nil),
                                finishReason: "stop",
                                usage: parseUsage(json)
                            ))
                            continuation.finish()
                            return
                        } else if let content = content {
                            continuation.yield(StreamChunk(
                                id: "",
                                model: request.model,
                                delta: .init(content: content, role: nil, images: nil),
                                finishReason: nil,
                                usage: nil
                            ))
                        }
                    }

                    // Stream ended without a `done` marker; finish cleanly.
                    continuation.finish()
                } catch {
                    continuation.finish(throwing: error)
                }
            }
        }
    }

    // MARK: - Credits

    /// Local models are free — there is no credit balance to report.
    func getCredits() async throws -> Credits? {
        nil
    }

    // MARK: - Helpers

    /// Builds an endpoint URL, throwing instead of crashing on a malformed
    /// base URL (the original force-unwrapped `URL(string:)`).
    private func endpoint(_ path: String) throws -> URL {
        guard let url = URL(string: "\(baseURL)\(path)") else {
            throw ProviderError.unknown("Invalid Ollama URL: \(baseURL)\(path)")
        }
        return url
    }

    /// Builds a JSON POST request for the given API path.
    private func makePostRequest(path: String, body: [String: Any]) throws -> URLRequest {
        var request = URLRequest(url: try endpoint(path))
        request.httpMethod = "POST"
        request.addValue("application/json", forHTTPHeaderField: "Content-Type")
        request.httpBody = try JSONSerialization.data(withJSONObject: body)
        return request
    }

    /// Validates an HTTP response, surfacing Ollama's JSON `error` message
    /// when a body is available.
    /// - Throws: `ProviderError.invalidResponse` for non-HTTP responses;
    ///   `ProviderError.unknown` for non-200 statuses.
    private func validate(_ response: URLResponse, data: Data?) throws {
        guard let httpResponse = response as? HTTPURLResponse else {
            Log.api.error("Ollama: invalid response (not HTTP)")
            throw ProviderError.invalidResponse
        }
        guard httpResponse.statusCode == 200 else {
            if let data = data,
               let errorObj = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
               let errorMsg = errorObj["error"] as? String {
                Log.api.error("Ollama HTTP \(httpResponse.statusCode): \(errorMsg)")
                throw ProviderError.unknown(errorMsg)
            }
            Log.api.error("Ollama HTTP \(httpResponse.statusCode)")
            throw ProviderError.unknown("Ollama HTTP \(httpResponse.statusCode)")
        }
    }

    /// Maps generation knobs onto Ollama's `options` object.
    /// - Returns: nil when no option is set, so callers can omit the key.
    private func makeOptions(maxTokens: Int?, temperature: Double?) -> [String: Any]? {
        var options: [String: Any] = [:]
        if let maxTokens = maxTokens { options["num_predict"] = maxTokens }       // Ollama's name for max tokens
        if let temperature = temperature { options["temperature"] = temperature }
        return options.isEmpty ? nil : options
    }

    /// Converts a `ChatRequest` into Ollama's `/api/chat` JSON body.
    private func buildRequestBody(from request: ChatRequest, stream: Bool) -> [String: Any] {
        var messages: [[String: Any]] = []

        // The system prompt becomes a leading system-role message.
        if let systemPrompt = request.systemPrompt {
            messages.append(["role": "system", "content": systemPrompt])
        }
        for msg in request.messages {
            messages.append(["role": msg.role.rawValue, "content": msg.content])
        }

        var body: [String: Any] = [
            "model": request.model,
            "messages": messages,
            "stream": stream
        ]
        // Subscript assignment of nil leaves "options" absent when empty.
        body["options"] = makeOptions(maxTokens: request.maxTokens, temperature: request.temperature)
        return body
    }

    /// Extracts token usage counters from an Ollama response object
    /// (present on non-streaming replies and on the final stream chunk).
    private func parseUsage(_ json: [String: Any]) -> ChatResponse.Usage {
        let promptTokens = json["prompt_eval_count"] as? Int ?? 0
        let completionTokens = json["eval_count"] as? Int ?? 0
        return ChatResponse.Usage(
            promptTokens: promptTokens,
            completionTokens: completionTokens,
            totalTokens: promptTokens + completionTokens
        )
    }

    /// Builds a `ChatResponse` from a non-streaming `/api/chat` payload.
    private func parseOllamaResponse(_ json: [String: Any], model: String) -> ChatResponse {
        let content = (json["message"] as? [String: Any])?["content"] as? String ?? ""
        return ChatResponse(
            id: UUID().uuidString,  // Ollama supplies no response id; synthesize one
            model: model,
            content: content,
            role: "assistant",
            finishReason: "stop",
            usage: parseUsage(json),
            created: Date()
        )
    }
}
|
||||
Reference in New Issue
Block a user