Files
oai-swift/oAI/Services/AnthropicOAuthService.swift

317 lines
10 KiB
Swift

//
// AnthropicOAuthService.swift
// oAI
//
// OAuth 2.0 PKCE flow for Anthropic Pro/Max subscription login
//
// SPDX-License-Identifier: AGPL-3.0-or-later
// Copyright (C) 2026 Rune Olsen
//
// This file is part of oAI.
//
// oAI is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// oAI is distributed in the hope that it will be useful, but WITHOUT
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
// or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General
// Public License for more details.
//
// You should have received a copy of the GNU Affero General Public
// License along with oAI. If not, see <https://www.gnu.org/licenses/>.
import Foundation
import CryptoKit
import Security
/// Drives the OAuth 2.0 Authorization Code + PKCE flow used to log in with an
/// Anthropic Pro/Max subscription, and persists the resulting tokens in the Keychain.
///
/// Flow:
/// 1. `generateAuthorizationURL()` creates a PKCE verifier and returns the browser URL.
/// 2. The user pastes the code shown after authorizing into `exchangeCode(_:)`.
/// 3. `getValidAccessToken()` hands out the stored token, refreshing it when close to expiry.
@Observable
class AnthropicOAuthService {
    static let shared = AnthropicOAuthService()

    // OAuth configuration (matches Claude Code CLI)
    private let clientId = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
    private let redirectURI = "https://console.anthropic.com/oauth/code/callback"
    private let scope = "org:create_api_key user:profile user:inference"
    private let tokenEndpoint = "https://console.anthropic.com/v1/oauth/token"

    /// How long before the recorded expiry `getValidAccessToken()` refreshes proactively,
    /// so a token does not lapse between being handed out and reaching the server.
    private let refreshLeeway: TimeInterval = 60

    // Keychain keys (generic-password items, keyed by account name)
    private enum Keys {
        static let accessToken = "com.oai.anthropic.oauth.accessToken"
        static let refreshToken = "com.oai.anthropic.oauth.refreshToken"
        static let expiresAt = "com.oai.anthropic.oauth.expiresAt"
    }

    // PKCE verifier for the login flow currently in progress (nil when none is in progress)
    private var currentVerifier: String?

    /// Bumped whenever stored credentials change (save or logout). The Keychain is
    /// invisible to Observation, so the token accessors read this counter to give
    /// observers of `isAuthenticated` / `accessToken` something to track.
    private var credentialsRevision = 0

    // Observable state
    var isAuthenticated: Bool { accessToken != nil }
    var isLoggingIn = false

    // MARK: - Token Access

    /// The stored access token, or `nil` if none is stored. It may be expired;
    /// use `getValidAccessToken()` when a usable token is required.
    var accessToken: String? {
        _ = credentialsRevision // register an Observation dependency on credential changes
        return getKeychainValue(for: Keys.accessToken)
    }

    private var refreshToken: String? {
        getKeychainValue(for: Keys.refreshToken)
    }

    /// Expiry instant, stored in the Keychain as seconds since 1970 in string form.
    private var expiresAt: Date? {
        _ = credentialsRevision // keep `isTokenExpired` observable as well
        guard let str = getKeychainValue(for: Keys.expiresAt),
              let interval = Double(str) else { return nil }
        return Date(timeIntervalSince1970: interval)
    }

    /// `true` when no expiry is recorded or the recorded expiry has passed.
    var isTokenExpired: Bool {
        guard let expires = expiresAt else { return true }
        return Date() >= expires
    }

    // MARK: - Step 1: Generate Authorization URL

    /// Starts a new login flow: creates and remembers a fresh PKCE verifier and
    /// returns the authorization URL to open in the browser.
    func generateAuthorizationURL() -> URL {
        let verifier = generateCodeVerifier()
        currentVerifier = verifier
        let challenge = generateCodeChallenge(from: verifier)

        var components = URLComponents(string: "https://claude.ai/oauth/authorize")!
        components.queryItems = [
            URLQueryItem(name: "code", value: "true"),
            URLQueryItem(name: "client_id", value: clientId),
            URLQueryItem(name: "response_type", value: "code"),
            URLQueryItem(name: "redirect_uri", value: redirectURI),
            URLQueryItem(name: "scope", value: scope),
            URLQueryItem(name: "code_challenge", value: challenge),
            URLQueryItem(name: "code_challenge_method", value: "S256"),
            // NOTE(review): the PKCE verifier doubles as `state`, so it appears in the
            // browser URL/history. A separate random state would keep the verifier
            // secret. Confirm the server accepts that before changing it.
            URLQueryItem(name: "state", value: verifier),
        ]
        return components.url!
    }

    // MARK: - Step 2: Exchange Code for Tokens

    /// Exchanges the code pasted by the user for tokens and stores them.
    /// - Parameter pastedCode: `"auth_code#state"` or just `"auth_code"`; surrounding whitespace is ignored.
    /// - Throws: `OAuthError.noVerifier` if no login flow was started, or any error from the token request.
    func exchangeCode(_ pastedCode: String) async throws {
        guard let verifier = currentVerifier else {
            throw OAuthError.noVerifier
        }

        // Code format: "auth_code#state". Split on the first "#" only; everything after it is the state.
        let trimmed = pastedCode.trimmingCharacters(in: .whitespacesAndNewlines)
        let authCode: String
        let state: String
        if let hashIndex = trimmed.firstIndex(of: "#") {
            authCode = String(trimmed[..<hashIndex])
            let pastedState = String(trimmed[trimmed.index(after: hashIndex)...])
            // "code#" with nothing after the separator: use the state we sent.
            state = pastedState.isEmpty ? verifier : pastedState
        } else {
            // If no # separator, treat entire string as the code
            authCode = trimmed
            state = verifier
        }

        Log.api.info("Exchanging OAuth code for tokens")

        let body: [String: String] = [
            "code": authCode,
            "state": state,
            "grant_type": "authorization_code",
            "client_id": clientId,
            "redirect_uri": redirectURI,
            "code_verifier": verifier,
        ]

        let tokenResponse = try await postTokenRequest(body)
        saveTokens(tokenResponse)
        currentVerifier = nil
        Log.api.info("OAuth login successful, token expires in \(tokenResponse.expiresIn)s")
    }

    // MARK: - Token Refresh

    /// Uses the stored refresh token to obtain and store a new access token.
    /// - Throws: `OAuthError.noRefreshToken` if none is stored, or any error from the token request.
    func refreshAccessToken() async throws {
        guard let refresh = refreshToken else {
            throw OAuthError.noRefreshToken
        }

        Log.api.info("Refreshing OAuth access token")

        let body: [String: String] = [
            "grant_type": "refresh_token",
            "refresh_token": refresh,
            "client_id": clientId,
        ]

        let tokenResponse = try await postTokenRequest(body)
        saveTokens(tokenResponse)
        Log.api.info("OAuth token refreshed successfully")
    }

    /// Returns a valid access token, refreshing if needed.
    ///
    /// The token is refreshed when no expiry is recorded or when it expires within
    /// `refreshLeeway` seconds.
    /// - Throws: `OAuthError.notAuthenticated` if no token is stored, or any refresh error.
    func getValidAccessToken() async throws -> String {
        guard let token = accessToken else {
            throw OAuthError.notAuthenticated
        }

        if let expires = expiresAt, Date().addingTimeInterval(refreshLeeway) < expires {
            return token
        }

        try await refreshAccessToken()
        guard let newToken = accessToken else {
            throw OAuthError.notAuthenticated
        }
        return newToken
    }

    // MARK: - Logout

    /// Deletes all stored tokens and abandons any in-progress login flow.
    func logout() {
        deleteKeychainValue(for: Keys.accessToken)
        deleteKeychainValue(for: Keys.refreshToken)
        deleteKeychainValue(for: Keys.expiresAt)
        currentVerifier = nil
        credentialsRevision += 1
        Log.api.info("OAuth logout complete")
    }

    // MARK: - PKCE Helpers

    /// 32 random bytes encoded as unpadded Base64URL (43 characters).
    private func generateCodeVerifier() -> String {
        let byteCount = 32
        var bytes = [UInt8](repeating: 0, count: byteCount)
        let status = SecRandomCopyBytes(kSecRandomDefault, byteCount, &bytes)
        if status != errSecSuccess {
            // Never continue with the zero-filled buffer: the verifier would be predictable.
            Log.api.error("SecRandomCopyBytes failed (OSStatus \(Int(status))); using system RNG")
            bytes = (0..<byteCount).map { _ in UInt8.random(in: .min ... .max) }
        }
        return Data(bytes).base64URLEncoded()
    }

    /// S256 challenge: Base64URL(SHA-256(verifier)).
    private func generateCodeChallenge(from verifier: String) -> String {
        let data = Data(verifier.utf8)
        let hash = SHA256.hash(data: data)
        return Data(hash).base64URLEncoded()
    }

    // MARK: - Token Request

    /// POSTs `body` as JSON to the token endpoint and decodes the token response.
    /// - Throws: `OAuthError.invalidResponse` for a non-HTTP response,
    ///   `OAuthError.tokenExchangeFailed` for a non-200 status, or a decoding/transport error.
    private func postTokenRequest(_ body: [String: String]) async throws -> TokenResponse {
        var request = URLRequest(url: URL(string: tokenEndpoint)!)
        request.httpMethod = "POST"
        request.addValue("application/json", forHTTPHeaderField: "Content-Type")
        request.httpBody = try JSONEncoder().encode(body)

        let (data, response) = try await URLSession.shared.data(for: request)

        guard let httpResponse = response as? HTTPURLResponse else {
            throw OAuthError.invalidResponse
        }

        guard httpResponse.statusCode == 200 else {
            let errorBody = String(data: data, encoding: .utf8) ?? "Unknown error"
            Log.api.error("OAuth token exchange failed HTTP \(httpResponse.statusCode): \(errorBody)")
            throw OAuthError.tokenExchangeFailed(httpResponse.statusCode, errorBody)
        }

        return try JSONDecoder().decode(TokenResponse.self, from: data)
    }

    // MARK: - Token Storage

    /// Persists the tokens. An existing refresh token is kept when the response omits one.
    private func saveTokens(_ response: TokenResponse) {
        setKeychainValue(response.accessToken, for: Keys.accessToken)
        if let refresh = response.refreshToken {
            setKeychainValue(refresh, for: Keys.refreshToken)
        }
        let expiresAt = Date().addingTimeInterval(TimeInterval(response.expiresIn))
        setKeychainValue(String(expiresAt.timeIntervalSince1970), for: Keys.expiresAt)
        credentialsRevision += 1
    }

    // MARK: - Keychain Helpers

    /// Reads a UTF-8 string stored as a generic-password item; nil if absent or unreadable.
    private func getKeychainValue(for key: String) -> String? {
        let query: [String: Any] = [
            kSecClass as String: kSecClassGenericPassword,
            kSecAttrAccount as String: key,
            kSecReturnData as String: true,
            kSecMatchLimit as String: kSecMatchLimitOne,
        ]
        var ref: AnyObject?
        guard SecItemCopyMatching(query as CFDictionary, &ref) == errSecSuccess,
              let data = ref as? Data,
              let value = String(data: data, encoding: .utf8) else {
            return nil
        }
        return value
    }

    /// Updates the item if it exists, otherwise adds it. Failures are logged, not thrown.
    private func setKeychainValue(_ value: String, for key: String) {
        let data = Data(value.utf8)
        let query: [String: Any] = [
            kSecClass as String: kSecClassGenericPassword,
            kSecAttrAccount as String: key,
        ]
        let attrs: [String: Any] = [kSecValueData as String: data]
        var status = SecItemUpdate(query as CFDictionary, attrs as CFDictionary)
        if status == errSecItemNotFound {
            var newItem = query
            newItem[kSecValueData as String] = data
            status = SecItemAdd(newItem as CFDictionary, nil)
        }
        if status != errSecSuccess {
            Log.api.error("Keychain write failed for \(key) (OSStatus \(Int(status)))")
        }
    }

    /// Deletes the item; a missing item is not an error. Other failures are logged.
    private func deleteKeychainValue(for key: String) {
        let query: [String: Any] = [
            kSecClass as String: kSecClassGenericPassword,
            kSecAttrAccount as String: key,
        ]
        let status = SecItemDelete(query as CFDictionary)
        if status != errSecSuccess && status != errSecItemNotFound {
            Log.api.error("Keychain delete failed for \(key) (OSStatus \(Int(status)))")
        }
    }

    // MARK: - Types

    /// Token endpoint response; `refresh_token` may be absent (e.g. on refresh).
    struct TokenResponse: Decodable {
        let accessToken: String
        let refreshToken: String?
        let expiresIn: Int

        enum CodingKeys: String, CodingKey {
            case accessToken = "access_token"
            case refreshToken = "refresh_token"
            case expiresIn = "expires_in"
        }
    }

    enum OAuthError: LocalizedError {
        case noVerifier
        case noRefreshToken
        case notAuthenticated
        case invalidResponse
        case tokenExchangeFailed(Int, String)

        var errorDescription: String? {
            switch self {
            case .noVerifier: return "No PKCE verifier — start the login flow first."
            case .noRefreshToken: return "No refresh token available. Please log in again."
            case .notAuthenticated: return "Not authenticated. Please log in."
            case .invalidResponse: return "Invalid response from Anthropic OAuth server."
            case .tokenExchangeFailed(let code, let body):
                return "Token exchange failed (HTTP \(code)): \(body)"
            }
        }
    }
}
// MARK: - Base64URL Encoding
private extension Data {
    /// Returns the bytes as unpadded Base64URL text (RFC 4648 §5):
    /// standard Base64 with `+` → `-`, `/` → `_`, and every `=` dropped.
    func base64URLEncoded() -> String {
        let standardAlphabet = base64EncodedString()
        let urlSafeCharacters: [Character] = standardAlphabet.compactMap { character in
            switch character {
            case "+": return "-"
            case "/": return "_"
            case "=": return nil
            default: return character
            }
        }
        return String(urlSafeCharacters)
    }
}