diff --git a/Sources/CodableDatastore/Persistence/Disk Persistence/AsyncThrowingBackpressureStream.swift b/Sources/CodableDatastore/Persistence/Disk Persistence/AsyncThrowingBackpressureStream.swift index 1839665..ff1ab91 100644 --- a/Sources/CodableDatastore/Persistence/Disk Persistence/AsyncThrowingBackpressureStream.swift +++ b/Sources/CodableDatastore/Persistence/Disk Persistence/AsyncThrowingBackpressureStream.swift @@ -8,50 +8,98 @@ import Foundation +/// A stream that limits reads based on the speed results are consumed at. +/// +/// A backpressure stream is consumed within a _reading task_, usually in a `for try await` loop. Separately, it is fed within an internal _writing task_ inherited during initialization. These may share the same parent task depending on the use case. +/// +/// Writes to the stream can be made one at a time until they are consumed by the reading task, usually via an iterator. If a write is not consumed, the writing task is suspended until the reading task is ready to consume the event. To this effect, a backpressure stream may hold onto at most a single pending event while waiting for a read to take place. Similarly, if a read happens before a write is ready, the reading task will be suspended, while the write will be processed immediately allowing a follow-up write to be made. +/// +/// The reading task may be cancelled at any time, immediately ending the loop, and propagaing the cancellation to the writing child task, stopping any more values from being provided to the stream. struct AsyncThrowingBackpressureStream: Sendable { fileprivate actor StateMachine { - var pendingEvents: [(CheckedContinuation, Result)] = [] - var eventsReadyContinuation: CheckedContinuation? + var pendingWriteEvents: [(CheckedContinuation, Result)] = [] + var pendingReadContinuation: CheckedContinuation? var wasCancelled = false - func provide(_ result: Result) async throws { - guard !wasCancelled else { throw CancellationError() } + func provide(_ result: Result, in continuation: CheckedContinuation) { + /// If reads were cancelled, propagate the cancellation to the provider without saving the result. + guard !wasCancelled else { + continuation.resume(throwing: CancellationError()) + return + } - try await withCheckedThrowingContinuation { continuation in - precondition(pendingEvents.isEmpty, "More than one event has bee queued on the stream.") - if let eventsReadyContinuation { - self.eventsReadyContinuation = nil - eventsReadyContinuation.resume(with: result) - continuation.resume() - } else { - pendingEvents.append((continuation, result)) - } + /// Ideally, no more than one pending event should be queued up, as a second event means backpressure isn't working. + precondition(pendingWriteEvents.isEmpty, "More than one event has been queued on the stream.") + + /// If a read is currently pending, signal that a new result has been provided. + if let pendingReadContinuation { + self.pendingReadContinuation = nil + pendingReadContinuation.resume(with: result) + continuation.resume() + } else { + /// If we aren't ready for events, queue the event and suspend the task until events are ready. This will stop more values from being provided (ie. the backpressure at work). + pendingWriteEvents.append((continuation, result)) } } + /// Cancel any reads by immediately signalling that no events are available to any pending read. + private func cancelPendingRead() { + wasCancelled = true + if let pendingReadContinuation { + self.pendingReadContinuation = nil + pendingReadContinuation.resume(throwing: CancellationError()) + } + } + + /// Consume the next value on the read task. + /// + /// There are two scenarios to consider here: + /// - A read happens before a write. + /// - A write happens before a read. + /// + /// In the first case, a continuation is saved and the read task is suspended. In the second case, a read is popped off the from of the pending write events and returned immediately. func consumeNext() async throws -> Element? { if Task.isCancelled { wasCancelled = true } - return try await withCheckedThrowingContinuation { continuation in - guard !pendingEvents.isEmpty else { - eventsReadyContinuation = continuation - return - } - let (providerContinuation, result) = pendingEvents.removeFirst() - continuation.resume(with: result) - if wasCancelled { - providerContinuation.resume(throwing: CancellationError()) - } else { - providerContinuation.resume() + return try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { continuation in + guard !pendingWriteEvents.isEmpty else { + /// If there are no pending events, suspend the reading task until one is signaled. + guard !wasCancelled else { + /// If the task was cancelled, stop here without waiting for the signal — the provider will be cancelled as soon as they try to provide their first value. + continuation.resume(throwing: CancellationError()) + return + } + pendingReadContinuation = continuation + return + } + + /// Otherwise, pop the first entry off the stack and return it. + let (providerContinuation, result) = pendingWriteEvents.removeFirst() + + /// Return the reading task with the result we have from the write queue. + continuation.resume(with: result) + + /// Determine if the provider should continue providing values or if it should be stopped here. + if wasCancelled { + providerContinuation.resume(throwing: CancellationError()) + } else { + providerContinuation.resume() + } } + } onCancel: { + Task { await cancelPendingRead() } } } deinit { - if let eventsReadyContinuation { - eventsReadyContinuation.resume(throwing: CancellationError()) + if let pendingReadContinuation { + pendingReadContinuation.resume(throwing: CancellationError()) + } + for (providerContinuation, _) in pendingWriteEvents { + providerContinuation.resume(throwing: CancellationError()) } } } @@ -64,17 +112,34 @@ struct AsyncThrowingBackpressureStream: Sendable { } func yield(_ value: Element) async throws { - guard let stateMachine else { throw CancellationError() } - try await stateMachine.provide(.success(value)) + do { + try await withCheckedThrowingContinuation { continuation in + guard let stateMachine else { + continuation.resume(throwing: CancellationError()) + return + } + Task { + await stateMachine.provide(.success(value), in: continuation) + } + } as Void + } catch { + throw error + } } fileprivate func finish(throwing error: Error? = nil) async throws { - guard let stateMachine else { throw CancellationError() } - if let error { - try await stateMachine.provide(.failure(error)) - } else { - try await stateMachine.provide(.success(nil)) - } + try await withCheckedThrowingContinuation { continuation in + guard let stateMachine else { continuation.resume(throwing: CancellationError()) + return + } + Task { + if let error { + await stateMachine.provide(.failure(error), in: continuation) + } else { + await stateMachine.provide(.success(nil), in: continuation) + } + } + } as Void } } @@ -102,6 +167,11 @@ extension AsyncThrowingBackpressureStream: AsyncInstances { func next() async throws -> Element? { try await stateMachine.consumeNext() } + + /// Used only for testing. + internal var wasCancelled: Bool { + get async { await stateMachine.wasCancelled } + } } func makeAsyncIterator() -> AsyncIterator { diff --git a/Tests/CodableDatastoreTests/AsyncThrowingBackpressureStreamTests.swift b/Tests/CodableDatastoreTests/AsyncThrowingBackpressureStreamTests.swift new file mode 100644 index 0000000..5d905ff --- /dev/null +++ b/Tests/CodableDatastoreTests/AsyncThrowingBackpressureStreamTests.swift @@ -0,0 +1,243 @@ +// +// AsyncThrowingBackpressureStreamTests.swift +// CodableDatastore +// +// Created by Dimitri Bouniol on 2026-01-23. +// Copyright © 2023-26 Mochi Development, Inc. All rights reserved. +// + +import XCTest +@testable import CodableDatastore + +final class AsyncThrowingBackpressureStreamTests: XCTestCase { + func testStreamForwardsResults() async throws { + let stream = AsyncThrowingBackpressureStream { continuation in + try await continuation.yield(0) + try await continuation.yield(1) + try await continuation.yield(2) + try await continuation.yield(3) + try await continuation.yield(4) + } + + let results = try await stream.collectInstances(upTo: .infinity) + + XCTAssertEqual(results, [0, 1, 2, 3, 4]) + } + + func testReadTaskSuspendsWriteTask() async throws { + let (writeContinuations, readProvider) = AsyncStream.makeStream(of: (Int, CheckedContinuation).self) + + let stream = AsyncThrowingBackpressureStream { continuation in + try await continuation.yield(0) + await withCheckedContinuation { continuation in + readProvider.yield((0, continuation)) + } + try await continuation.yield(1) + await withCheckedContinuation { continuation in + readProvider.yield((1, continuation)) + } + try await continuation.yield(2) + await withCheckedContinuation { continuation in + readProvider.yield((2, continuation)) + } + try await continuation.yield(3) + await withCheckedContinuation { continuation in + readProvider.yield((3, continuation)) + } + try await continuation.yield(4) + await withCheckedContinuation { continuation in + readProvider.yield((4, continuation)) + } + } + + let iterator = stream.makeAsyncIterator() + var consumer = writeContinuations.makeAsyncIterator() + + var result = try await iterator.next() + XCTAssertEqual(result, 0) + var accumulatedResult = await consumer.next()! + accumulatedResult.1.resume() + XCTAssertEqual(accumulatedResult.0, 0) + + result = try await iterator.next() + XCTAssertEqual(result, 1) + accumulatedResult = await consumer.next()! + accumulatedResult.1.resume() + XCTAssertEqual(accumulatedResult.0, 1) + + result = try await iterator.next() + XCTAssertEqual(result, 2) + accumulatedResult = await consumer.next()! + accumulatedResult.1.resume() + XCTAssertEqual(accumulatedResult.0, 2) + + result = try await iterator.next() + XCTAssertEqual(result, 3) + accumulatedResult = await consumer.next()! + accumulatedResult.1.resume() + XCTAssertEqual(accumulatedResult.0, 3) + + result = try await iterator.next() + XCTAssertEqual(result, 4) + accumulatedResult = await consumer.next()! + accumulatedResult.1.resume() + XCTAssertEqual(accumulatedResult.0, 4) + + result = try await iterator.next() + XCTAssertEqual(result, nil) + } + + func testWriteTaskNeverProgressesWhenReadsDoNotHappen() async throws { + let (writeContinuations, readProvider) = AsyncStream.makeStream(of: (Int, CheckedContinuation).self) + + let stream = AsyncThrowingBackpressureStream { continuation in + try await continuation.yield(0) + await withCheckedContinuation { continuation in + readProvider.yield((0, continuation)) + } + try await continuation.yield(1) + XCTFail() + } + + let iterator = stream.makeAsyncIterator() + var consumer = writeContinuations.makeAsyncIterator() + + let result = try await iterator.next() + XCTAssertEqual(result, 0) + let accumulatedResult = await consumer.next()! + accumulatedResult.1.resume() + XCTAssertEqual(accumulatedResult.0, 0) + + try await Task.sleep(for: .seconds(1)) + } + + func testWriteTaskNeverProgressesWhenReadsAreCancelled() async throws { + let expectation = expectation(description: "Writes were cancelled") + + let task = Task { + let stream = AsyncThrowingBackpressureStream { continuation in + try await continuation.yield(0) + do { + try await continuation.yield(1) + XCTFail() + } catch { + XCTAssertEqual(error is CancellationError, true) + expectation.fulfill() + throw error + } + } + + let iterator = stream.makeAsyncIterator() + let result = try await iterator.next() + XCTAssertEqual(result, 0) + + withUnsafeCurrentTask { task in + task?.cancel() + } + + do { + /// Perform two reads, because we can't control if the write happens before this happens (in which case the first read will succeed) or if it happens after (in which the first read will fail). Either way, the second read will always fail and return nil. + _ = try await iterator.next() + _ = try await iterator.next() + XCTFail() + } catch { + XCTAssertEqual(error is CancellationError, true) + } + } + + try? await task.value + + await fulfillment(of: [expectation], timeout: 10) + } + + func testReadingNotSuspendedWhenCancelledBeforeWrite() async throws { + let (writeContinuations, readProvider) = AsyncStream.makeStream(of: CheckedContinuation.self) + + let expectation = expectation(description: "Writes were cancelled") + + let task = Task { + let stream = AsyncThrowingBackpressureStream { continuation in + try await continuation.yield(0) + await withCheckedContinuation { continuation in + readProvider.yield(continuation) + } + do { + try await continuation.yield(1) + XCTFail() + } catch { + XCTAssertEqual(error is CancellationError, true) + expectation.fulfill() + throw error + } + } + + let iterator = stream.makeAsyncIterator() + let result = try await iterator.next() + XCTAssertEqual(result, 0) + + withUnsafeCurrentTask { task in + task?.cancel() + } + + do { + /// This read is guaranteed to happen before the write, which is blocked below. It should _never_ stall until the write is made. + _ = try await iterator.next() + XCTFail() + } catch { + /// Let the write happen strictly after the read, in its own task so signaling doesn't "see" the cancellation. + XCTAssertEqual(error is CancellationError, true) + await Task { await writeContinuations.first(where: { _ in true })!.resume() }.value + } + } + + try? await task.value + + await fulfillment(of: [expectation], timeout: 10) + } + + func testWritingUnsuspendsWhenReadsCancelledButNeverMade() async throws { + let (writeContinuations, readProvider) = AsyncStream.makeStream(of: CheckedContinuation.self) + + let expectation = expectation(description: "Writes were cancelled") + + let task = Task { + var stream: AsyncThrowingBackpressureStream? = AsyncThrowingBackpressureStream { continuation in + try await continuation.yield(0) + await withCheckedContinuation { continuation in + readProvider.yield(continuation) + } + do { + try await continuation.yield(1) + XCTFail() + expectation.fulfill() + } catch { + XCTAssertEqual(error is CancellationError, true) + expectation.fulfill() + throw error + } + } + + let iterator = stream!.makeAsyncIterator() + let result = try await iterator.next() + XCTAssertEqual(result, 0) + + withUnsafeCurrentTask { task in + task?.cancel() + } + + /// Let the write happen strictly after cancellation, in its own task so signaling doesn't "see" the cancellation. + await Task { await writeContinuations.first(where: { _ in true })!.resume() }.value + + /// The stream can't be marked as cancelled if another read never happens. + let wasCancelled = await iterator.wasCancelled + XCTAssertEqual(wasCancelled, false) + + stream = nil + } + + try? await task.value + readProvider.finish() + + await fulfillment(of: [expectation], timeout: 10) + } +}