Skip to content

Make setting HTTPBody.iteratorCreated thread-safe #95

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 2 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 8 additions & 17 deletions Sources/OpenAPIRuntime/Interface/HTTPBody.swift
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,6 @@ public final class HTTPBody: @unchecked Sendable {
return locked_iteratorCreated
}

/// Verifying that creating another iterator is allowed based on
/// the values of `iterationBehavior` and `locked_iteratorCreated`.
/// - Throws: If another iterator is not allowed to be created.
private func checkIfCanCreateIterator() throws {
lock.lock()
defer { lock.unlock() }
guard iterationBehavior == .single else { return }
if locked_iteratorCreated { throw TooManyIterationsError() }
}

/// Tries to mark an iterator as created, verifying that it is allowed
/// based on the values of `iterationBehavior` and `locked_iteratorCreated`.
/// - Throws: If another iterator is not allowed to be created.
Expand Down Expand Up @@ -341,10 +331,12 @@ extension HTTPBody: AsyncSequence {
/// Creates and returns an asynchronous iterator
///
/// - Returns: An asynchronous iterator for byte chunks.
/// - Note: The returned sequence throws an error if no further iterations are allowed. See ``IterationBehavior``.
public func makeAsyncIterator() -> AsyncIterator {
// The crash on error is intentional here.
try! tryToMarkIteratorCreated()
return .init(sequence.makeAsyncIterator())
do {
try tryToMarkIteratorCreated()
return .init(sequence.makeAsyncIterator())
} catch { return .init(throwing: error) }
}
}

Expand Down Expand Up @@ -381,10 +373,6 @@ extension HTTPBody {
/// than `maxBytes`.
/// - Returns: A byte chunk containing all the accumulated bytes.
fileprivate func collect(upTo maxBytes: Int) async throws -> ByteChunk {

// Check that we're allowed to iterate again.
try checkIfCanCreateIterator()

// If the length is known, verify it's within the limit.
if case .known(let knownBytes) = length {
guard knownBytes <= maxBytes else { throw TooManyBytesError(maxBytes: maxBytes) }
Expand Down Expand Up @@ -563,6 +551,9 @@ extension HTTPBody {
var iterator = iterator
self.produceNext = { try await iterator.next() }
}
/// Creates an iterator throwing the given error when iterated.
/// - Parameter error: The error to throw on iteration.
fileprivate init(throwing error: any Error) { self.produceNext = { throw error } }

/// Advances the iterator to the next element and returns it asynchronously.
///
Expand Down
22 changes: 9 additions & 13 deletions Sources/OpenAPIRuntime/Multipart/MultipartPublicTypes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -209,16 +209,6 @@ public final class MultipartBody<Part: Sendable>: @unchecked Sendable {
var errorDescription: String? { description }
}

/// Verifying that creating another iterator is allowed based on the values of `iterationBehavior`
/// and `locked_iteratorCreated`.
/// - Throws: If another iterator is not allowed to be created.
internal func checkIfCanCreateIterator() throws {
lock.lock()
defer { lock.unlock() }
guard iterationBehavior == .single else { return }
if locked_iteratorCreated { throw TooManyIterationsError() }
}

/// Tries to mark an iterator as created, verifying that it is allowed based on the values
/// of `iterationBehavior` and `locked_iteratorCreated`.
/// - Throws: If another iterator is not allowed to be created.
Expand Down Expand Up @@ -331,10 +321,12 @@ extension MultipartBody: AsyncSequence {
/// Creates and returns an asynchronous iterator
///
/// - Returns: An asynchronous iterator for parts.
/// - Note: The returned sequence throws an error if no further iterations are allowed. See ``IterationBehavior``.
public func makeAsyncIterator() -> AsyncIterator {
// The crash on error is intentional here.
try! tryToMarkIteratorCreated()
return .init(sequence.makeAsyncIterator())
do {
try tryToMarkIteratorCreated()
return .init(sequence.makeAsyncIterator())
} catch { return .init(throwing: error) }
}
}

Expand All @@ -355,6 +347,10 @@ extension MultipartBody {
self.produceNext = { try await iterator.next() }
}

/// Creates an iterator throwing the given error when iterated.
/// - Parameter error: The error to throw on iteration.
fileprivate init(throwing error: any Error) { self.produceNext = { throw error } }

/// Advances the iterator to the next element and returns it asynchronously.
///
/// - Returns: The next element in the sequence, or `nil` if there are no more elements.
Expand Down
5 changes: 5 additions & 0 deletions Tests/OpenAPIRuntimeTests/Interface/Test_HTTPBody.swift
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,11 @@ final class Test_Body: Test_Runtime {
_ = try await String(collecting: body, upTo: .max)
XCTFail("Expected an error to be thrown")
} catch {}

do {
for try await _ in body {}
XCTFail("Expected an error to be thrown")
} catch {}
}

func testIterationBehavior_multiple() async throws {
Expand Down
49 changes: 49 additions & 0 deletions Tests/OpenAPIRuntimeTests/Interface/Test_MultipartBody.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftOpenAPIGenerator open source project
//
// Copyright (c) 2023 Apple Inc. and the SwiftOpenAPIGenerator project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import XCTest
@_spi(Generated) @testable import OpenAPIRuntime
import Foundation

final class Test_MultipartBody: XCTestCase {

func testIterationBehavior_single() async throws {
let sourceSequence = (0..<Int.random(in: 2..<10)).map { _ in UUID().uuidString }
let body = MultipartBody(sourceSequence, iterationBehavior: .single)

XCTAssertFalse(body.testing_iteratorCreated)

let iterated = try await body.reduce("") { $0 + $1 }
XCTAssertEqual(iterated, sourceSequence.joined())

XCTAssertTrue(body.testing_iteratorCreated)

do {
for try await _ in body {}
XCTFail("Expected an error to be thrown")
} catch {}
}

func testIterationBehavior_multiple() async throws {
let sourceSequence = (0..<Int.random(in: 2..<10)).map { _ in UUID().uuidString }
let body = MultipartBody(sourceSequence, iterationBehavior: .multiple)

XCTAssertFalse(body.testing_iteratorCreated)
for _ in 0..<2 {
let iterated = try await body.reduce("") { $0 + $1 }
XCTAssertEqual(iterated, sourceSequence.joined())
XCTAssertTrue(body.testing_iteratorCreated)
}
}

}