Initial commit
This commit is contained in:
298
oAI/Services/AnthropicOAuthService.swift
Normal file
298
oAI/Services/AnthropicOAuthService.swift
Normal file
@@ -0,0 +1,298 @@
|
||||
//
|
||||
// AnthropicOAuthService.swift
|
||||
// oAI
|
||||
//
|
||||
// OAuth 2.0 PKCE flow for Anthropic Pro/Max subscription login
|
||||
//
|
||||
|
||||
import Foundation
|
||||
import CryptoKit
|
||||
import Security
|
||||
|
||||
@Observable
|
||||
class AnthropicOAuthService {
|
||||
static let shared = AnthropicOAuthService()
|
||||
|
||||
// OAuth configuration (matches Claude Code CLI)
|
||||
private let clientId = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
private let redirectURI = "https://console.anthropic.com/oauth/code/callback"
|
||||
private let scope = "org:create_api_key user:profile user:inference"
|
||||
private let tokenEndpoint = "https://console.anthropic.com/v1/oauth/token"
|
||||
|
||||
// Keychain keys
|
||||
private enum Keys {
|
||||
static let accessToken = "com.oai.anthropic.oauth.accessToken"
|
||||
static let refreshToken = "com.oai.anthropic.oauth.refreshToken"
|
||||
static let expiresAt = "com.oai.anthropic.oauth.expiresAt"
|
||||
}
|
||||
|
||||
// PKCE state for current flow
|
||||
private var currentVerifier: String?
|
||||
|
||||
// Observable state
|
||||
var isAuthenticated: Bool { accessToken != nil }
|
||||
var isLoggingIn = false
|
||||
|
||||
// MARK: - Token Access
|
||||
|
||||
var accessToken: String? {
|
||||
getKeychainValue(for: Keys.accessToken)
|
||||
}
|
||||
|
||||
private var refreshToken: String? {
|
||||
getKeychainValue(for: Keys.refreshToken)
|
||||
}
|
||||
|
||||
private var expiresAt: Date? {
|
||||
guard let str = getKeychainValue(for: Keys.expiresAt),
|
||||
let interval = Double(str) else { return nil }
|
||||
return Date(timeIntervalSince1970: interval)
|
||||
}
|
||||
|
||||
var isTokenExpired: Bool {
|
||||
guard let expires = expiresAt else { return true }
|
||||
return Date() >= expires
|
||||
}
|
||||
|
||||
// MARK: - Step 1: Generate Authorization URL
|
||||
|
||||
func generateAuthorizationURL() -> URL {
|
||||
let verifier = generateCodeVerifier()
|
||||
currentVerifier = verifier
|
||||
let challenge = generateCodeChallenge(from: verifier)
|
||||
|
||||
var components = URLComponents(string: "https://claude.ai/oauth/authorize")!
|
||||
components.queryItems = [
|
||||
URLQueryItem(name: "code", value: "true"),
|
||||
URLQueryItem(name: "client_id", value: clientId),
|
||||
URLQueryItem(name: "response_type", value: "code"),
|
||||
URLQueryItem(name: "redirect_uri", value: redirectURI),
|
||||
URLQueryItem(name: "scope", value: scope),
|
||||
URLQueryItem(name: "code_challenge", value: challenge),
|
||||
URLQueryItem(name: "code_challenge_method", value: "S256"),
|
||||
URLQueryItem(name: "state", value: verifier),
|
||||
]
|
||||
|
||||
return components.url!
|
||||
}
|
||||
|
||||
// MARK: - Step 2: Exchange Code for Tokens
|
||||
|
||||
func exchangeCode(_ pastedCode: String) async throws {
|
||||
guard let verifier = currentVerifier else {
|
||||
throw OAuthError.noVerifier
|
||||
}
|
||||
|
||||
// Code format: "auth_code#state"
|
||||
let parts = pastedCode.trimmingCharacters(in: .whitespacesAndNewlines).components(separatedBy: "#")
|
||||
let authCode: String
|
||||
let state: String
|
||||
|
||||
if parts.count >= 2 {
|
||||
authCode = parts[0]
|
||||
state = parts.dropFirst().joined(separator: "#")
|
||||
} else {
|
||||
// If no # separator, treat entire string as the code
|
||||
authCode = pastedCode.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
state = verifier
|
||||
}
|
||||
|
||||
Log.api.info("Exchanging OAuth code for tokens")
|
||||
|
||||
let body: [String: String] = [
|
||||
"code": authCode,
|
||||
"state": state,
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": clientId,
|
||||
"redirect_uri": redirectURI,
|
||||
"code_verifier": verifier,
|
||||
]
|
||||
|
||||
let tokenResponse = try await postTokenRequest(body)
|
||||
saveTokens(tokenResponse)
|
||||
currentVerifier = nil
|
||||
|
||||
Log.api.info("OAuth login successful, token expires in \(tokenResponse.expiresIn)s")
|
||||
}
|
||||
|
||||
// MARK: - Token Refresh
|
||||
|
||||
func refreshAccessToken() async throws {
|
||||
guard let refresh = refreshToken else {
|
||||
throw OAuthError.noRefreshToken
|
||||
}
|
||||
|
||||
Log.api.info("Refreshing OAuth access token")
|
||||
|
||||
let body: [String: String] = [
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh,
|
||||
"client_id": clientId,
|
||||
]
|
||||
|
||||
let tokenResponse = try await postTokenRequest(body)
|
||||
saveTokens(tokenResponse)
|
||||
|
||||
Log.api.info("OAuth token refreshed successfully")
|
||||
}
|
||||
|
||||
/// Returns a valid access token, refreshing if needed
|
||||
func getValidAccessToken() async throws -> String {
|
||||
guard let token = accessToken else {
|
||||
throw OAuthError.notAuthenticated
|
||||
}
|
||||
|
||||
if isTokenExpired {
|
||||
try await refreshAccessToken()
|
||||
guard let newToken = accessToken else {
|
||||
throw OAuthError.notAuthenticated
|
||||
}
|
||||
return newToken
|
||||
}
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
// MARK: - Logout
|
||||
|
||||
func logout() {
|
||||
deleteKeychainValue(for: Keys.accessToken)
|
||||
deleteKeychainValue(for: Keys.refreshToken)
|
||||
deleteKeychainValue(for: Keys.expiresAt)
|
||||
currentVerifier = nil
|
||||
Log.api.info("OAuth logout complete")
|
||||
}
|
||||
|
||||
// MARK: - PKCE Helpers
|
||||
|
||||
private func generateCodeVerifier() -> String {
|
||||
var bytes = [UInt8](repeating: 0, count: 32)
|
||||
_ = SecRandomCopyBytes(kSecRandomDefault, bytes.count, &bytes)
|
||||
return Data(bytes).base64URLEncoded()
|
||||
}
|
||||
|
||||
private func generateCodeChallenge(from verifier: String) -> String {
|
||||
let data = Data(verifier.utf8)
|
||||
let hash = SHA256.hash(data: data)
|
||||
return Data(hash).base64URLEncoded()
|
||||
}
|
||||
|
||||
// MARK: - Token Request
|
||||
|
||||
private func postTokenRequest(_ body: [String: String]) async throws -> TokenResponse {
|
||||
var request = URLRequest(url: URL(string: tokenEndpoint)!)
|
||||
request.httpMethod = "POST"
|
||||
request.addValue("application/json", forHTTPHeaderField: "Content-Type")
|
||||
request.httpBody = try JSONEncoder().encode(body)
|
||||
|
||||
let (data, response) = try await URLSession.shared.data(for: request)
|
||||
|
||||
guard let httpResponse = response as? HTTPURLResponse else {
|
||||
throw OAuthError.invalidResponse
|
||||
}
|
||||
|
||||
guard httpResponse.statusCode == 200 else {
|
||||
let errorBody = String(data: data, encoding: .utf8) ?? "Unknown error"
|
||||
Log.api.error("OAuth token exchange failed HTTP \(httpResponse.statusCode): \(errorBody)")
|
||||
throw OAuthError.tokenExchangeFailed(httpResponse.statusCode, errorBody)
|
||||
}
|
||||
|
||||
return try JSONDecoder().decode(TokenResponse.self, from: data)
|
||||
}
|
||||
|
||||
// MARK: - Token Storage
|
||||
|
||||
private func saveTokens(_ response: TokenResponse) {
|
||||
setKeychainValue(response.accessToken, for: Keys.accessToken)
|
||||
if let refresh = response.refreshToken {
|
||||
setKeychainValue(refresh, for: Keys.refreshToken)
|
||||
}
|
||||
let expiresAt = Date().addingTimeInterval(TimeInterval(response.expiresIn))
|
||||
setKeychainValue(String(expiresAt.timeIntervalSince1970), for: Keys.expiresAt)
|
||||
}
|
||||
|
||||
// MARK: - Keychain Helpers
|
||||
|
||||
private func getKeychainValue(for key: String) -> String? {
|
||||
let query: [String: Any] = [
|
||||
kSecClass as String: kSecClassGenericPassword,
|
||||
kSecAttrAccount as String: key,
|
||||
kSecReturnData as String: true,
|
||||
kSecMatchLimit as String: kSecMatchLimitOne,
|
||||
]
|
||||
var ref: AnyObject?
|
||||
guard SecItemCopyMatching(query as CFDictionary, &ref) == errSecSuccess,
|
||||
let data = ref as? Data,
|
||||
let value = String(data: data, encoding: .utf8) else {
|
||||
return nil
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
private func setKeychainValue(_ value: String, for key: String) {
|
||||
guard let data = value.data(using: .utf8) else { return }
|
||||
let query: [String: Any] = [
|
||||
kSecClass as String: kSecClassGenericPassword,
|
||||
kSecAttrAccount as String: key,
|
||||
]
|
||||
let attrs: [String: Any] = [kSecValueData as String: data]
|
||||
let status = SecItemUpdate(query as CFDictionary, attrs as CFDictionary)
|
||||
if status == errSecItemNotFound {
|
||||
var newItem = query
|
||||
newItem[kSecValueData as String] = data
|
||||
SecItemAdd(newItem as CFDictionary, nil)
|
||||
}
|
||||
}
|
||||
|
||||
private func deleteKeychainValue(for key: String) {
|
||||
let query: [String: Any] = [
|
||||
kSecClass as String: kSecClassGenericPassword,
|
||||
kSecAttrAccount as String: key,
|
||||
]
|
||||
SecItemDelete(query as CFDictionary)
|
||||
}
|
||||
|
||||
// MARK: - Types
|
||||
|
||||
struct TokenResponse: Decodable {
|
||||
let accessToken: String
|
||||
let refreshToken: String?
|
||||
let expiresIn: Int
|
||||
|
||||
enum CodingKeys: String, CodingKey {
|
||||
case accessToken = "access_token"
|
||||
case refreshToken = "refresh_token"
|
||||
case expiresIn = "expires_in"
|
||||
}
|
||||
}
|
||||
|
||||
enum OAuthError: LocalizedError {
|
||||
case noVerifier
|
||||
case noRefreshToken
|
||||
case notAuthenticated
|
||||
case invalidResponse
|
||||
case tokenExchangeFailed(Int, String)
|
||||
|
||||
var errorDescription: String? {
|
||||
switch self {
|
||||
case .noVerifier: return "No PKCE verifier — start the login flow first."
|
||||
case .noRefreshToken: return "No refresh token available. Please log in again."
|
||||
case .notAuthenticated: return "Not authenticated. Please log in."
|
||||
case .invalidResponse: return "Invalid response from Anthropic OAuth server."
|
||||
case .tokenExchangeFailed(let code, let body):
|
||||
return "Token exchange failed (HTTP \(code)): \(body)"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Base64URL Encoding
|
||||
|
||||
private extension Data {
|
||||
func base64URLEncoded() -> String {
|
||||
base64EncodedString()
|
||||
.replacingOccurrences(of: "+", with: "-")
|
||||
.replacingOccurrences(of: "/", with: "_")
|
||||
.replacingOccurrences(of: "=", with: "")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user