Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 140 additions & 41 deletions Sources/CodingPlanAuth/Infrastructure/Server/LocalCallbackServer.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// LocalCallbackServer.swift
// CodingPlanAuth

import Darwin
import Foundation
import SwiftWebServer

Expand All @@ -25,6 +26,8 @@ private final class WebServerBox: @unchecked Sendable {
}

actor LocalCallbackServer {
private static let ephemeralPortStartupAttempts = 5

nonisolated let port: UInt16
nonisolated let callbackPath: String

Expand All @@ -51,47 +54,7 @@ actor LocalCallbackServer {
func start() async throws -> CallbackParameters {
try await withCheckedThrowingContinuation { continuation in
self.continuation = continuation
let server = WebServerBox(SwiftWebServer())
self.server = server

// Capture state into local constants so the handler closure doesn't
// synchronously reach into the actor.
let redirectBaseURL = self.redirectBaseURL
let responseHTML = self.responseHTML

server.server.get(callbackPath) { [weak self] request, response in
let query = request.queryParameters
guard let code = query["code"] else {
response.status(.badRequest, error: "Missing authorization code")
return
}
let params = CallbackParameters(code: code, state: query["state"])

if let redirectBaseURL {
var components = URLComponents(string: redirectBaseURL)
var items = components?.queryItems ?? []
items.append(URLQueryItem(name: "code", value: code))
if let state = query["state"] {
items.append(URLQueryItem(name: "state", value: state))
}
components?.queryItems = items
response.redirectTemporary(components?.string ?? redirectBaseURL)
} else {
response.status(.ok)
response.header(.contentType, "text/html; charset=utf-8")
response.send(responseHTML)
}

let target = self
Task { await target?.resume(with: params) }
}

let listenPort = self.port
Task { @MainActor in
server.server.listen(UInt(listenPort)) { }
let result = Self.startupResult(from: server.server.status)
Task { await self.finishStartup(with: result) }
}
Task { await self.startListening() }
}
}

Expand Down Expand Up @@ -149,6 +112,142 @@ actor LocalCallbackServer {
}
}

private func startListening() async {
let shouldResolveEphemeralPort = port == 0
let maxAttempts = shouldResolveEphemeralPort ? Self.ephemeralPortStartupAttempts : 1

for attempt in 1...maxAttempts {
guard continuation != nil else { return }

let listenPort: UInt16
do {
listenPort = try Self.resolveListenPort(port)
} catch {
let authError = error as? AuthError ?? .callbackServerError(error.localizedDescription)
finishStartup(with: .failure(authError))
return
}

let server = makeServer()
self.server = server

let result = await MainActor.run {
server.server.listen(UInt(listenPort)) { }
return Self.startupResult(from: server.server.status)
}

if case .failure(let error) = result,
shouldResolveEphemeralPort,
attempt < maxAttempts,
Self.isResolvedPortBindFailure(error) {
self.server = nil
await MainActor.run {
server.server.close()
}
continue
}

finishStartup(with: result)
return
}
}

private func makeServer() -> WebServerBox {
let server = WebServerBox(SwiftWebServer())

// Capture state into local constants so the handler closure doesn't
// synchronously reach into the actor.
let redirectBaseURL = self.redirectBaseURL
let responseHTML = self.responseHTML

server.server.get(callbackPath) { [weak self] request, response in
let query = request.queryParameters
guard let code = query["code"] else {
response.status(.badRequest, error: "Missing authorization code")
return
}
let params = CallbackParameters(code: code, state: query["state"])

if let redirectBaseURL {
var components = URLComponents(string: redirectBaseURL)
var items = components?.queryItems ?? []
items.append(URLQueryItem(name: "code", value: code))
if let state = query["state"] {
items.append(URLQueryItem(name: "state", value: state))
}
components?.queryItems = items
response.redirectTemporary(components?.string ?? redirectBaseURL)
} else {
response.status(.ok)
response.header(.contentType, "text/html; charset=utf-8")
response.send(responseHTML)
}

let target = self
Task { await target?.resume(with: params) }
}

return server
}

private static func isResolvedPortBindFailure(_ error: AuthError) -> Bool {
guard case .callbackServerError(let message) = error else { return false }
return message.hasPrefix("Failed to bind IPv4 socket on port ")
}

private static func resolveListenPort(_ port: UInt16) throws -> UInt16 {
guard port == 0 else { return port }

// SwiftWebServer reports `.running(port: 0)` for OS-assigned ports, so
// reserve a candidate port first and retry startup if that race loses.
let socketDescriptor = socket(AF_INET, SOCK_STREAM, 0)
guard socketDescriptor >= 0 else {
throw AuthError.callbackServerError("Could not create socket to reserve callback port")
}
defer { close(socketDescriptor) }

var reuseAddress = 1
guard setsockopt(
socketDescriptor,
SOL_SOCKET,
SO_REUSEADDR,
&reuseAddress,
socklen_t(MemoryLayout.size(ofValue: reuseAddress))
) == 0 else {
throw AuthError.callbackServerError("Could not configure callback port socket")
}

var address = sockaddr_in()
address.sin_family = sa_family_t(AF_INET)
address.sin_addr.s_addr = INADDR_ANY
address.sin_port = 0

let bindResult = withUnsafePointer(to: &address) { pointer in
pointer.withMemoryRebound(to: sockaddr.self, capacity: 1) {
bind(socketDescriptor, $0, socklen_t(MemoryLayout<sockaddr_in>.size))
}
}
guard bindResult == 0 else {
throw AuthError.callbackServerError("Could not reserve callback port")
}

var length = socklen_t(MemoryLayout<sockaddr_in>.size)
let nameResult = withUnsafeMutablePointer(to: &address) { pointer in
pointer.withMemoryRebound(to: sockaddr.self, capacity: 1) {
getsockname(socketDescriptor, $0, &length)
}
}
guard nameResult == 0 else {
throw AuthError.callbackServerError("Could not read reserved callback port")
}

let reservedPort = UInt16(bigEndian: address.sin_port)
guard reservedPort > 0 else {
throw AuthError.callbackServerError("Reserved callback port was invalid")
}
return reservedPort
}

private func finishStartup(with result: StartupResult) {
switch result {
case .success(let port):
Expand Down
26 changes: 26 additions & 0 deletions Sources/CodingPlanAuth/Presentation/AuthState.swift
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ public final class AuthState {
let creds = try await service.credentials(for: providerId)
update(credentials: creds)
} catch {
if shouldClearAuthentication(after: error) {
currentCredentials = nil
isAuthenticated = false
}
setError(error)
}
}
Expand Down Expand Up @@ -105,4 +109,26 @@ public final class AuthState {
self.lastError = .networkError(error.localizedDescription)
}
}

private func shouldClearAuthentication(after error: any Error) -> Bool {
guard let authError = error as? AuthError else { return false }
switch authError {
case .tokenExchangeFailed(let statusCode, let message):
if statusCode == 401 || statusCode == 403 {
return true
}
let normalizedMessage = message.lowercased()
return [
"invalid_grant",
"invalid grant",
"invalid_token",
"invalid token",
"no refresh token",
].contains { normalizedMessage.contains($0) }
case .notAuthenticated, .unsupportedProvider:
return true
default:
return false
}
}
}
53 changes: 37 additions & 16 deletions Sources/CodingPlanAuth/Presentation/BrowserAuthSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,28 +36,49 @@ public final class BrowserAuthSession: NSObject {
throw AuthError.browserPresentationFailed("A browser authentication session is already active.")
}

return try await withCheckedThrowingContinuation { continuation in
self.continuation = continuation
let session = ASWebAuthenticationSession(
url: url,
callbackURLScheme: callbackScheme
) { [weak self] callbackURL, error in
Task { @MainActor in
self?.finish(callbackURL: callbackURL, error: error)
return try await withTaskCancellationHandler {
try await withCheckedThrowingContinuation { continuation in
do {
try Task.checkCancellation()
} catch {
continuation.resume(throwing: AuthError.cancelled)
return
Comment thread
atom2ueki marked this conversation as resolved.
}

self.continuation = continuation
let session = ASWebAuthenticationSession(
url: url,
callbackURLScheme: callbackScheme
) { [weak self] callbackURL, error in
Task { @MainActor in
self?.finish(callbackURL: callbackURL, error: error)
}
}
session.prefersEphemeralWebBrowserSession = false
session.presentationContextProvider = self
self.session = session
if !session.start() {
finish(
callbackURL: nil,
error: AuthError.browserPresentationFailed("The system refused to start the browser authentication session.")
)
}
}
session.prefersEphemeralWebBrowserSession = false
session.presentationContextProvider = self
self.session = session
if !session.start() {
finish(
callbackURL: nil,
error: AuthError.browserPresentationFailed("The system refused to start the browser authentication session.")
)
} onCancel: {
Task { @MainActor [weak self] in
self?.cancelActiveSession()
}
}
}

private func cancelActiveSession() {
guard let continuation else { return }
self.continuation = nil
session?.cancel()
session = nil
continuation.resume(throwing: AuthError.cancelled)
}

private func finish(callbackURL: URL?, error: (any Error)?) {
guard let continuation else { return }
self.continuation = nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -562,19 +562,29 @@ public struct OpenAICodexClient: Sendable {
}

private static func sseDataPayloads(from stream: String) -> [String] {
stream
.replacingOccurrences(of: "\r\n", with: "\n")
.components(separatedBy: "\n\n")
.compactMap { event in
let dataLines = event
.split(separator: "\n", omittingEmptySubsequences: false)
.compactMap { line -> String? in
guard line.hasPrefix("data:") else { return nil }
return String(line.dropFirst("data:".count)).trimmingCharacters(in: .whitespaces)
}
guard !dataLines.isEmpty else { return nil }
return dataLines.joined(separator: "\n")
var payloads: [String] = []
var current: [String] = []

func flush() {
guard !current.isEmpty else { return }
payloads.append(current.joined(separator: "\n"))
current.removeAll(keepingCapacity: true)
}

let normalized = stream.replacingOccurrences(of: "\r\n", with: "\n")
for line in normalized.split(separator: "\n", omittingEmptySubsequences: false) {
let trimmed = line.trimmingCharacters(in: .whitespacesAndNewlines)
if trimmed.isEmpty || trimmed.hasPrefix("event:") {
flush()
} else if trimmed.hasPrefix("data:") {
current.append(
String(trimmed.dropFirst("data:".count))
.trimmingCharacters(in: .whitespaces)
)
}
}
flush()
return payloads
}

private static func backendErrorMessage(from data: Data) -> String {
Expand Down
Loading
Loading