Skip to content
26 changes: 16 additions & 10 deletions Sources/Testing/Traits/ConditionTrait.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
/// - ``Trait/disabled(if:_:sourceLocation:)``
/// - ``Trait/disabled(_:sourceLocation:_:)``
public struct ConditionTrait: TestTrait, SuiteTrait {
/// The result of evaluating the condition.
public typealias Evaluation = (Bool, comment: Comment?)

/// An enumeration describing the kinds of conditions that can be represented
/// by an instance of this type.
enum Kind: Sendable {
Expand All @@ -30,7 +33,7 @@ public struct ConditionTrait: TestTrait, SuiteTrait {
/// `false` and a comment is also returned, it is used in place of the
/// value of the associated trait's ``ConditionTrait/comment`` property.
/// If this function returns `true`, the returned comment is ignored.
case conditional(_ body: @Sendable () async throws -> (Bool, comment: Comment?))
case conditional(_ body: @Sendable () async throws -> Evaluation)

/// Create an instance of this type associated with a trait that is
/// conditional on the result of calling a function.
Expand All @@ -41,7 +44,7 @@ public struct ConditionTrait: TestTrait, SuiteTrait {
///
/// - Returns: An instance of this type.
static func conditional(_ body: @escaping @Sendable () async throws -> Bool) -> Self {
conditional { () -> (Bool, comment: Comment?) in
conditional { () -> Evaluation in
return (try await body(), nil)
}
}
Expand Down Expand Up @@ -79,19 +82,22 @@ public struct ConditionTrait: TestTrait, SuiteTrait {

/// The source location where this trait was specified.
public var sourceLocation: SourceLocation

public func prepare(for test: Test) async throws {
let result: Bool
var commentOverride: Comment?


/// Returns the result of evaluating the condition.
@_spi(Experimental)
public func evaluate() async throws -> Evaluation {
switch kind {
case let .conditional(condition):
(result, commentOverride) = try await condition()
try await condition()
case let .unconditional(unconditionalValue):
result = unconditionalValue
(unconditionalValue, nil)
}
}

public func prepare(for test: Test) async throws {
let (isEnabled, commentOverride) = try await evaluate()

if !result {
if !isEnabled {
// We don't need to consider including a backtrace here because it will
// primarily contain frames in the testing library, not user code. If an
// error was thrown by a condition evaluated above, the caller _should_
Expand Down