Skip to content
291 changes: 265 additions & 26 deletions Sources/TestSupport/Utilities.swift
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func expectNoChanges<T: BinaryInteger>(_ check: @autoclosure () -> T, by differe
///
/// - Note: `oracle` is also checked for conformance to the
/// laws.
public func checkEquatable<Instances: Collection>(
public func XCTCheckEquatable<Instances: Collection>(
_ instances: Instances,
oracle: (Instances.Index, Instances.Index) -> Bool,
allowBrokenTransitivity: Bool = false,
Expand All @@ -179,7 +179,7 @@ public func checkEquatable<Instances: Collection>(
line: UInt = #line
) where Instances.Element: Equatable {
let indices = Array(instances.indices)
_checkEquatableImpl(
_XCTCheckEquatableImpl(
Array(instances),
oracle: { oracle(indices[$0], indices[$1]) },
allowBrokenTransitivity: allowBrokenTransitivity,
Expand All @@ -188,15 +188,7 @@ public func checkEquatable<Instances: Collection>(
line: line)
}

private class Box<T> {
var value: T

init(_ value: T) {
self.value = value
}
}

internal func _checkEquatableImpl<Instance : Equatable>(
internal func _XCTCheckEquatableImpl<Instance : Equatable>(
_ instances: [Instance],
oracle: (Int, Int) -> Bool,
allowBrokenTransitivity: Bool = false,
Expand Down Expand Up @@ -271,23 +263,14 @@ internal func _checkEquatableImpl<Instance : Equatable>(
}
}

func hash<H: Hashable>(_ value: H, salt: Int? = nil) -> Int {
var hasher = Hasher()
if let salt = salt {
hasher.combine(salt)
}
hasher.combine(value)
return hasher.finalize()
}

public func checkHashable<Instances: Collection>(
public func XCTCheckHashable<Instances: Collection>(
_ instances: Instances,
equalityOracle: (Instances.Index, Instances.Index) -> Bool,
allowIncompleteHashing: Bool = false,
_ message: @autoclosure () -> String = "",
file: StaticString = #filePath, line: UInt = #line
) where Instances.Element: Hashable {
checkHashable(
XCTCheckHashable(
instances,
equalityOracle: equalityOracle,
hashEqualityOracle: equalityOracle,
Expand All @@ -298,7 +281,7 @@ public func checkHashable<Instances: Collection>(
}


public func checkHashable<Instances: Collection>(
public func XCTCheckHashable<Instances: Collection>(
_ instances: Instances,
equalityOracle: (Instances.Index, Instances.Index) -> Bool,
hashEqualityOracle: (Instances.Index, Instances.Index) -> Bool,
Expand All @@ -307,7 +290,7 @@ public func checkHashable<Instances: Collection>(
file: StaticString = #filePath, line: UInt = #line
) where Instances.Element: Hashable {

checkEquatable(
XCTCheckEquatable(
instances,
oracle: equalityOracle,
message(),
Expand Down Expand Up @@ -390,7 +373,7 @@ public func checkHashable<Instances: Collection>(
/// Test that the elements of `groups` consist of instances that satisfy the
/// semantic requirements of `Hashable`, with each group defining a distinct
/// equivalence class under `==`.
public func checkHashableGroups<Groups: Collection>(
public func XCTCheckHashableGroups<Groups: Collection>(
_ groups: Groups,
_ message: @autoclosure () -> String = "",
allowIncompleteHashing: Bool = false,
Expand All @@ -405,7 +388,7 @@ public func checkHashableGroups<Groups: Collection>(
func equalityOracle(_ lhs: Int, _ rhs: Int) -> Bool {
return groupIndices[lhs] == groupIndices[rhs]
}
checkHashable(
XCTCheckHashable(
instances,
equalityOracle: equalityOracle,
hashEqualityOracle: equalityOracle,
Expand Down Expand Up @@ -477,3 +460,259 @@ func testExpectedToFailWithCheck<T>(check: (String) -> Bool, _ test: @escaping
}
}

// MARK: - swift-testing Helpers

import Testing

/// Test that the elements of `instances` satisfy the semantic
/// requirements of `Equatable`, using `oracle` to generate equality
/// expectations from pairs of positions in `instances`.
///
/// - Note: `oracle` is also checked for conformance to the
/// laws.
func checkEquatable<Instances : Collection>(
_ instances: Instances,
oracle: (Instances.Index, Instances.Index) -> Bool,
allowBrokenTransitivity: Bool = false,
_ message: @autoclosure () -> String = "",
sourceLocation: SourceLocation = #_sourceLocation
) where Instances.Element: Equatable {
let indices = Array(instances.indices)
_checkEquatable(
instances,
oracle: { oracle(indices[$0], indices[$1]) },
allowBrokenTransitivity: allowBrokenTransitivity,
message(),
sourceLocation: sourceLocation
)
}

func _checkEquatable<Instances : Collection>(
_ _instances: Instances,
oracle: (Int, Int) -> Bool,
allowBrokenTransitivity: Bool = false,
_ message: @autoclosure () -> String = "",
sourceLocation: SourceLocation = #_sourceLocation
) where Instances.Element: Equatable {
let instances = Array(_instances)

// For each index (which corresponds to an instance being tested) track the
// set of equal instances.
var transitivityScoreboard: [Box<Set<Int>>] =
instances.indices.map { _ in Box([]) }

for i in instances.indices {
let x = instances[i]
#expect(oracle(i, i), "bad oracle: broken reflexivity at index \(i)")

for j in instances.indices {
let y = instances[j]

let predictedXY = oracle(i, j)
#expect(
predictedXY == oracle(j, i),
"bad oracle: broken symmetry between indices \(i), \(j)",
sourceLocation: sourceLocation
)

let isEqualXY = x == y
#expect(
predictedXY == isEqualXY,
"""
\((predictedXY
? "expected equal, found not equal"
: "expected not equal, found equal"))
lhs (at index \(i)): \(String(reflecting: x))
rhs (at index \(j)): \(String(reflecting: y))
""",
sourceLocation: sourceLocation
)

// Not-equal is an inverse of equal.
#expect(
isEqualXY != (x != y),
"""
lhs (at index \(i)): \(String(reflecting: x))
rhs (at index \(j)): \(String(reflecting: y))
""",
sourceLocation: sourceLocation
)

if !allowBrokenTransitivity {
// Check transitivity of the predicate represented by the oracle.
// If we are adding the instance `j` into an equivalence set, check that
// it is equal to every other instance in the set.
if predictedXY && i < j && transitivityScoreboard[i].value.insert(j).inserted {
if transitivityScoreboard[i].value.count == 1 {
transitivityScoreboard[i].value.insert(i)
}
for k in transitivityScoreboard[i].value {
#expect(
oracle(j, k),
"bad oracle: broken transitivity at indices \(i), \(j), \(k)",
sourceLocation: sourceLocation
)
// No need to check equality between actual values, we will check
// them with the checks above.
}
precondition(transitivityScoreboard[j].value.isEmpty)
transitivityScoreboard[j] = transitivityScoreboard[i]
}
}
}
}
}

public func checkHashable<Instances: Collection>(
_ instances: Instances,
equalityOracle: (Instances.Index, Instances.Index) -> Bool,
allowIncompleteHashing: Bool = false,
_ message: @autoclosure () -> String = "",
sourceLocation: SourceLocation = #_sourceLocation
) where Instances.Element: Hashable {
checkHashable(
instances,
equalityOracle: equalityOracle,
hashEqualityOracle: equalityOracle,
allowIncompleteHashing: allowIncompleteHashing,
message(),
sourceLocation: sourceLocation)
}

func checkHashable<Instances: Collection>(
_ instances: Instances,
equalityOracle: (Instances.Index, Instances.Index) -> Bool,
hashEqualityOracle: (Instances.Index, Instances.Index) -> Bool,
allowIncompleteHashing: Bool = false,
_ message: @autoclosure () -> String = "",
sourceLocation: SourceLocation = #_sourceLocation
) where Instances.Element: Hashable {
checkEquatable(
instances,
oracle: equalityOracle,
message(),
sourceLocation: sourceLocation
)

for i in instances.indices {
let x = instances[i]
for j in instances.indices {
let y = instances[j]
let predicted = hashEqualityOracle(i, j)
#expect(
predicted == hashEqualityOracle(j, i),
"bad hash oracle: broken symmetry between indices \(i), \(j)",
sourceLocation: sourceLocation
)
if x == y {
#expect(
predicted,
"""
bad hash oracle: equality must imply hash equality
lhs (at index \(i)): \(x)
rhs (at index \(j)): \(y)
""",
sourceLocation: sourceLocation
)
}
if predicted {
#expect(
hash(x) == hash(y),
"""
hash(into:) expected to match, found to differ
lhs (at index \(i)): \(x)
rhs (at index \(j)): \(y)
""",
sourceLocation: sourceLocation
)
#expect(
x.hashValue == y.hashValue,
"""
hashValue expected to match, found to differ
lhs (at index \(i)): \(x)
rhs (at index \(j)): \(y)
""",
sourceLocation: sourceLocation
)
#expect(
x._rawHashValue(seed: 0) == y._rawHashValue(seed: 0),
"""
_rawHashValue(seed:) expected to match, found to differ
lhs (at index \(i)): \(x)
rhs (at index \(j)): \(y)
""",
sourceLocation: sourceLocation
)
} else if !allowIncompleteHashing {
// Try a few different seeds; at least one of them should discriminate
// between the hashes. It is extremely unlikely this check will fail
// all ten attempts, unless the type's hash encoding is not unique,
// or unless the hash equality oracle is wrong.
#expect(
(0..<10).contains { hash(x, salt: $0) != hash(y, salt: $0) },
"""
hash(into:) expected to differ, found to match
lhs (at index \(i)): \(x)
rhs (at index \(j)): \(y)
""",
sourceLocation: sourceLocation
)
#expect(
(0..<10).contains { i in
x._rawHashValue(seed: i) != y._rawHashValue(seed: i)
},
"""
_rawHashValue(seed:) expected to differ, found to match
lhs (at index \(i)): \(x)
rhs (at index \(j)): \(y)
""",
sourceLocation: sourceLocation
)
}
}
}
}

/// Test that the elements of `groups` consist of instances that satisfy the
/// semantic requirements of `Hashable`, with each group defining a distinct
/// equivalence class under `==`.
public func checkHashableGroups<Groups: Collection>(
_ groups: Groups,
_ message: @autoclosure () -> String = "",
allowIncompleteHashing: Bool = false,
sourceLocation: SourceLocation = #_sourceLocation
) where Groups.Element: Collection, Groups.Element.Element: Hashable {
let instances = groups.flatMap { $0 }
// groupIndices[i] is the index of the element in groups that contains
// instances[i].
let groupIndices =
zip(0..., groups).flatMap { i, group in group.map { _ in i } }
func equalityOracle(_ lhs: Int, _ rhs: Int) -> Bool {
return groupIndices[lhs] == groupIndices[rhs]
}
checkHashable(
instances,
equalityOracle: equalityOracle,
hashEqualityOracle: equalityOracle,
allowIncompleteHashing: allowIncompleteHashing,
sourceLocation: sourceLocation)
}

// MARK: - Private Types

private class Box<T> {
var value: T

init(_ value: T) {
self.value = value
}
}

private func hash<H: Hashable>(_ value: H, salt: Int? = nil) -> Int {
var hasher = Hasher()
if let salt = salt {
hasher.combine(salt)
}
hasher.combine(value)
return hasher.finalize()
}
Loading
Loading