317 lines
10 KiB
Swift
317 lines
10 KiB
Swift
//
|
|
// AnthropicOAuthService.swift
|
|
// oAI
|
|
//
|
|
// OAuth 2.0 PKCE flow for Anthropic Pro/Max subscription login
|
|
//
|
|
// SPDX-License-Identifier: AGPL-3.0-or-later
|
|
// Copyright (C) 2026 Rune Olsen
|
|
//
|
|
// This file is part of oAI.
|
|
//
|
|
// oAI is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU Affero General Public License as
|
|
// published by the Free Software Foundation, either version 3 of the
|
|
// License, or (at your option) any later version.
|
|
//
|
|
// oAI is distributed in the hope that it will be useful, but WITHOUT
|
|
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
|
|
// or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General
|
|
// Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU Affero General Public
|
|
// License along with oAI. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
|
|
import Foundation
|
|
import CryptoKit
|
|
import Security
|
|
|
|
/// Runs the OAuth 2.0 authorization-code + PKCE login used for Anthropic
/// Pro/Max subscriptions and persists the resulting tokens in the Keychain.
///
/// Flow:
/// 1. `generateAuthorizationURL()` creates a PKCE verifier/challenge pair and
///    returns the authorization URL to open in a browser.
/// 2. The user pastes the code shown on the callback page, and
///    `exchangeCode(_:)` trades it for access/refresh tokens.
/// 3. `getValidAccessToken()` returns the access token, refreshing it first
///    when it has expired or is about to.
@Observable
class AnthropicOAuthService {

    static let shared = AnthropicOAuthService()

    // OAuth configuration (matches Claude Code CLI)
    private let clientId = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
    private let redirectURI = "https://console.anthropic.com/oauth/code/callback"
    private let scope = "org:create_api_key user:profile user:inference"
    private let tokenEndpoint = "https://console.anthropic.com/v1/oauth/token"

    /// A token expiring within this many seconds is refreshed proactively, so a
    /// token handed out by `getValidAccessToken()` does not lapse mid-request.
    private let refreshLeeway: TimeInterval = 60

    // Keychain account names (generic-password items) for the stored tokens
    private enum Keys {
        static let accessToken = "com.oai.anthropic.oauth.accessToken"
        static let refreshToken = "com.oai.anthropic.oauth.refreshToken"
        static let expiresAt = "com.oai.anthropic.oauth.expiresAt"
    }

    // PKCE verifier for the login flow currently in progress (nil when none)
    private var currentVerifier: String?

    /// Incremented whenever tokens are written or removed. Keychain contents
    /// are invisible to Observation, so token-derived properties read this
    /// tracked counter to register a dependency and notify observers.
    private var tokenRevision = 0

    // Observable state
    var isAuthenticated: Bool { accessToken != nil }
    // NOTE(review): not toggled inside this class; presumably set by the login UI.
    var isLoggingIn = false

    // MARK: - Token Access

    /// The stored access token, or `nil` when logged out. Expiry is not checked here.
    var accessToken: String? {
        _ = tokenRevision  // register an observation dependency on token changes
        return getKeychainValue(for: Keys.accessToken)
    }

    private var refreshToken: String? {
        getKeychainValue(for: Keys.refreshToken)
    }

    /// Absolute expiry of the access token, stored as seconds since 1970.
    private var expiresAt: Date? {
        _ = tokenRevision  // register an observation dependency on token changes
        guard let str = getKeychainValue(for: Keys.expiresAt),
              let interval = Double(str) else { return nil }
        return Date(timeIntervalSince1970: interval)
    }

    /// `true` when the expiry has passed or is unknown.
    var isTokenExpired: Bool {
        guard let expires = expiresAt else { return true }
        return Date() >= expires
    }

    // MARK: - Step 1: Generate Authorization URL

    /// Starts a login flow: creates and remembers a fresh PKCE verifier, then
    /// builds the authorization URL carrying its S256 challenge.
    func generateAuthorizationURL() -> URL {
        let verifier = generateCodeVerifier()
        currentVerifier = verifier
        let challenge = generateCodeChallenge(from: verifier)

        var components = URLComponents(string: "https://claude.ai/oauth/authorize")!
        components.queryItems = [
            URLQueryItem(name: "code", value: "true"),
            URLQueryItem(name: "client_id", value: clientId),
            URLQueryItem(name: "response_type", value: "code"),
            URLQueryItem(name: "redirect_uri", value: redirectURI),
            URLQueryItem(name: "scope", value: scope),
            URLQueryItem(name: "code_challenge", value: challenge),
            URLQueryItem(name: "code_challenge_method", value: "S256"),
            // NOTE(review): the verifier doubles as `state`, which exposes it in the
            // authorization URL. Kept as-is to match the upstream client; confirm
            // with the server's expectations before changing.
            URLQueryItem(name: "state", value: verifier),
        ]

        return components.url!
    }

    // MARK: - Step 2: Exchange Code for Tokens

    /// Exchanges the code pasted from the callback page for tokens and stores them.
    ///
    /// - Parameter pastedCode: Either `"auth_code#state"` or a bare code.
    ///   Surrounding whitespace is ignored.
    /// - Throws: `OAuthError.noVerifier` if no login flow was started,
    ///   `OAuthError.tokenExchangeFailed` / `.invalidResponse` for server
    ///   failures, or any networking/decoding error.
    func exchangeCode(_ pastedCode: String) async throws {
        guard let verifier = currentVerifier else {
            throw OAuthError.noVerifier
        }

        // Split on the first "#" only; everything after it is the state. Without
        // a separator the whole string is the code, and the state sent in the
        // authorization URL (the verifier) is reused.
        let trimmed = pastedCode.trimmingCharacters(in: .whitespacesAndNewlines)
        let authCode: String
        let state: String
        if let separator = trimmed.firstIndex(of: "#") {
            authCode = String(trimmed[..<separator])
            state = String(trimmed[trimmed.index(after: separator)...])
        } else {
            authCode = trimmed
            state = verifier
        }

        Log.api.info("Exchanging OAuth code for tokens")

        let body: [String: String] = [
            "code": authCode,
            "state": state,
            "grant_type": "authorization_code",
            "client_id": clientId,
            "redirect_uri": redirectURI,
            "code_verifier": verifier,
        ]

        let tokenResponse = try await postTokenRequest(body)
        saveTokens(tokenResponse)
        currentVerifier = nil

        Log.api.info("OAuth login successful, token expires in \(tokenResponse.expiresIn)s")
    }

    // MARK: - Token Refresh

    /// Uses the stored refresh token to obtain and store a new access token.
    /// - Throws: `OAuthError.noRefreshToken` when none is stored, or any error
    ///   from the token request.
    func refreshAccessToken() async throws {
        guard let refresh = refreshToken else {
            throw OAuthError.noRefreshToken
        }

        Log.api.info("Refreshing OAuth access token")

        let body: [String: String] = [
            "grant_type": "refresh_token",
            "refresh_token": refresh,
            "client_id": clientId,
        ]

        let tokenResponse = try await postTokenRequest(body)
        saveTokens(tokenResponse)

        Log.api.info("OAuth token refreshed successfully")
    }

    /// Returns a valid access token, refreshing if needed.
    ///
    /// An expired token (or one with unknown expiry) must be refreshed, and any
    /// refresh error is thrown. A token that is still valid but expires within
    /// `refreshLeeway` is refreshed early; if that early refresh fails, the
    /// still-valid current token is returned instead.
    /// - Throws: `OAuthError.notAuthenticated` when logged out, or refresh errors.
    func getValidAccessToken() async throws -> String {
        guard let token = accessToken else {
            throw OAuthError.notAuthenticated
        }

        if isTokenExpired {
            return try await refreshedAccessToken()
        }

        if let expires = expiresAt, Date().addingTimeInterval(refreshLeeway) >= expires {
            do {
                return try await refreshedAccessToken()
            } catch {
                Log.api.error("Proactive OAuth token refresh failed: \(error.localizedDescription)")
                return token
            }
        }

        return token
    }

    /// Refreshes the tokens and returns the newly stored access token.
    private func refreshedAccessToken() async throws -> String {
        try await refreshAccessToken()
        guard let newToken = accessToken else {
            throw OAuthError.notAuthenticated
        }
        return newToken
    }

    // MARK: - Logout

    /// Removes all stored tokens and abandons any in-progress login flow.
    func logout() {
        deleteKeychainValue(for: Keys.accessToken)
        deleteKeychainValue(for: Keys.refreshToken)
        deleteKeychainValue(for: Keys.expiresAt)
        currentVerifier = nil
        tokenRevision &+= 1  // notify observers (e.g. `isAuthenticated`)
        Log.api.info("OAuth logout complete")
    }

    // MARK: - PKCE Helpers

    /// Creates a 32-byte random verifier, base64url-encoded (43 characters).
    private func generateCodeVerifier() -> String {
        var bytes = [UInt8](repeating: 0, count: 32)
        let status = SecRandomCopyBytes(kSecRandomDefault, bytes.count, &bytes)
        if status != errSecSuccess {
            // Never use the zero-filled buffer: a predictable verifier defeats PKCE.
            // SystemRandomNumberGenerator is CSPRNG-backed on Apple platforms.
            Log.api.error("SecRandomCopyBytes failed (OSStatus \(status)); using SystemRandomNumberGenerator")
            bytes = (0..<bytes.count).map { _ in UInt8.random(in: .min ... .max) }
        }
        return Data(bytes).base64URLEncoded()
    }

    /// S256 challenge: base64url(SHA-256(ASCII verifier)).
    private func generateCodeChallenge(from verifier: String) -> String {
        let data = Data(verifier.utf8)
        let hash = SHA256.hash(data: data)
        return Data(hash).base64URLEncoded()
    }

    // MARK: - Token Request

    /// POSTs `body` as JSON to the token endpoint and decodes the token response.
    /// - Throws: `OAuthError.invalidResponse` for a non-HTTP response,
    ///   `OAuthError.tokenExchangeFailed` for any status other than 200, or
    ///   networking/decoding errors.
    private func postTokenRequest(_ body: [String: String]) async throws -> TokenResponse {
        var request = URLRequest(url: URL(string: tokenEndpoint)!)
        request.httpMethod = "POST"
        request.addValue("application/json", forHTTPHeaderField: "Content-Type")
        request.httpBody = try JSONEncoder().encode(body)

        let (data, response) = try await URLSession.shared.data(for: request)

        guard let httpResponse = response as? HTTPURLResponse else {
            throw OAuthError.invalidResponse
        }

        guard httpResponse.statusCode == 200 else {
            let errorBody = String(data: data, encoding: .utf8) ?? "Unknown error"
            Log.api.error("OAuth token exchange failed HTTP \(httpResponse.statusCode): \(errorBody)")
            throw OAuthError.tokenExchangeFailed(httpResponse.statusCode, errorBody)
        }

        return try JSONDecoder().decode(TokenResponse.self, from: data)
    }

    // MARK: - Token Storage

    /// Persists a token response. If the response carries no refresh token,
    /// the previously stored one is kept. Expiry is stored as an absolute time.
    private func saveTokens(_ response: TokenResponse) {
        setKeychainValue(response.accessToken, for: Keys.accessToken)
        if let refresh = response.refreshToken {
            setKeychainValue(refresh, for: Keys.refreshToken)
        }
        let expiry = Date().addingTimeInterval(TimeInterval(response.expiresIn))
        setKeychainValue(String(expiry.timeIntervalSince1970), for: Keys.expiresAt)
        tokenRevision &+= 1  // notify observers (e.g. `isAuthenticated`)
    }

    // MARK: - Keychain Helpers

    /// Reads a UTF-8 string stored as a generic password under `key`; nil if absent or unreadable.
    private func getKeychainValue(for key: String) -> String? {
        let query: [String: Any] = [
            kSecClass as String: kSecClassGenericPassword,
            kSecAttrAccount as String: key,
            kSecReturnData as String: true,
            kSecMatchLimit as String: kSecMatchLimitOne,
        ]
        var ref: AnyObject?
        guard SecItemCopyMatching(query as CFDictionary, &ref) == errSecSuccess,
              let data = ref as? Data,
              let value = String(data: data, encoding: .utf8) else {
            return nil
        }
        return value
    }

    /// Upserts `value` under `key`: updates an existing item, otherwise adds one.
    /// Failures are logged rather than silently ignored.
    private func setKeychainValue(_ value: String, for key: String) {
        let data = Data(value.utf8)
        let query: [String: Any] = [
            kSecClass as String: kSecClassGenericPassword,
            kSecAttrAccount as String: key,
        ]
        let attrs: [String: Any] = [kSecValueData as String: data]
        var status = SecItemUpdate(query as CFDictionary, attrs as CFDictionary)
        if status == errSecItemNotFound {
            var newItem = query
            newItem[kSecValueData as String] = data
            status = SecItemAdd(newItem as CFDictionary, nil)
        }
        if status != errSecSuccess {
            Log.api.error("Keychain write failed for \(key) (OSStatus \(status))")
        }
    }

    /// Deletes the item stored under `key`. A missing item is not an error.
    private func deleteKeychainValue(for key: String) {
        let query: [String: Any] = [
            kSecClass as String: kSecClassGenericPassword,
            kSecAttrAccount as String: key,
        ]
        let status = SecItemDelete(query as CFDictionary)
        if status != errSecSuccess && status != errSecItemNotFound {
            Log.api.error("Keychain delete failed for \(key) (OSStatus \(status))")
        }
    }

    // MARK: - Types

    /// JSON body returned by the token endpoint (snake_case keys).
    struct TokenResponse: Decodable {
        let accessToken: String
        /// May be absent (e.g. on refresh); the stored one is then kept.
        let refreshToken: String?
        /// Lifetime of the access token in seconds.
        let expiresIn: Int

        enum CodingKeys: String, CodingKey {
            case accessToken = "access_token"
            case refreshToken = "refresh_token"
            case expiresIn = "expires_in"
        }
    }

    /// Errors raised by the OAuth flow.
    enum OAuthError: LocalizedError {
        case noVerifier
        case noRefreshToken
        case notAuthenticated
        case invalidResponse
        /// HTTP status code and raw response body.
        case tokenExchangeFailed(Int, String)

        var errorDescription: String? {
            switch self {
            case .noVerifier: return "No PKCE verifier — start the login flow first."
            case .noRefreshToken: return "No refresh token available. Please log in again."
            case .notAuthenticated: return "Not authenticated. Please log in."
            case .invalidResponse: return "Invalid response from Anthropic OAuth server."
            case .tokenExchangeFailed(let code, let body):
                return "Token exchange failed (HTTP \(code)): \(body)"
            }
        }
    }
}
|
|
|
|
// MARK: - Base64URL Encoding
|
|
|
|
private extension Data {
    /// Returns the bytes as unpadded base64url text (RFC 4648 §5).
    ///
    /// Standard base64 is produced first. Each `+` then becomes `-`, each `/`
    /// becomes `_`, and every `=` padding character is dropped.
    func base64URLEncoded() -> String {
        let standard = base64EncodedString()
        var urlSafe = ""
        urlSafe.reserveCapacity(standard.count)
        for character in standard {
            switch character {
            case "+":
                urlSafe.append("-")
            case "/":
                urlSafe.append("_")
            case "=":
                continue
            default:
                urlSafe.append(character)
            }
        }
        return urlSafe
    }
}
|