Skip to content

Commit c8c37be

Browse files
committed
Do __cmp() typechecking at runtime instead of compile time
1 parent be162e5 commit c8c37be

File tree

1 file changed

+67
-90
lines changed

1 file changed

+67
-90
lines changed

Sources/Testing/Expectations/ExpectationContext.swift

Lines changed: 67 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -253,13 +253,77 @@ extension __ExpectationContext {
253253
return resultComponents.joined(separator: ", ")
254254
}
255255

256+
/// Capture the difference between `lhs` and `rhs` at runtime.
257+
///
258+
/// - Parameters:
259+
/// - lhs: The left-hand operand.
260+
/// - rhs: The right-hand operand.
261+
/// - opID: A value that uniquely identifies the binary operation expression
262+
/// of which `lhs` and `rhs` are operands.
263+
///
264+
/// This function performs additional type checking of `lhs` and `rhs` at
265+
/// runtime. If we instead overload the caller (`__cmp()`) it puts extra
266+
/// compile-time pressure on the type checker that we don't want.
267+
@usableFromInline func captureDifferences<T, U>(_ lhs: T, _ rhs: U, _ opID: __ExpressionID) {
268+
#if !hasFeature(Embedded) // no existentials
269+
if let lhs = lhs as? any StringProtocol {
270+
func open<V>(_ lhs: V, _ rhs: U) where V: StringProtocol {
271+
guard let rhs = rhs as? V else {
272+
return
273+
}
274+
differences[opID] = {
275+
// Compare strings by line, not by character.
276+
let lhsLines = String(lhs).split(whereSeparator: \.isNewline)
277+
let rhsLines = String(rhs).split(whereSeparator: \.isNewline)
278+
279+
if lhsLines.count == 1 && rhsLines.count == 1 {
280+
// There are no newlines in either string, so there's no meaningful
281+
// per-line difference. Bail.
282+
return nil
283+
}
284+
285+
let diff = lhsLines.difference(from: rhsLines)
286+
if diff.isEmpty {
287+
// The strings must have compared on a per-character basis, or this
288+
// operator doesn't behave the way we expected. Bail.
289+
return nil
290+
}
291+
292+
return CollectionDifference<Any>(diff)
293+
}
294+
}
295+
open(lhs, rhs)
296+
} else if lhs is any RangeExpression {
297+
// Do _not_ perform a diffing operation on `lhs` and `rhs`. Range
298+
// expressions are not usefully diffable the way other kinds of
299+
// collections are. SEE: https://github.com/swiftlang/swift-testing/issues/639
300+
} else if let lhs = lhs as? any BidirectionalCollection {
301+
func open<V>(_ lhs: V, _ rhs: U) where V: BidirectionalCollection {
302+
guard let rhs = rhs as? V,
303+
let elementType = V.Element.self as? any Equatable.Type else {
304+
return
305+
}
306+
differences[opID] = {
307+
func open<E>(_: E.Type) -> CollectionDifference<Any> where E: Equatable {
308+
let lhs: some BidirectionalCollection<E> = lhs.lazy.map { unsafeBitCast($0, to: E.self) }
309+
let rhs: some BidirectionalCollection<E> = rhs.lazy.map { unsafeBitCast($0, to: E.self) }
310+
return CollectionDifference<Any>(lhs.difference(from: rhs))
311+
}
312+
return open(elementType)
313+
}
314+
}
315+
open(lhs, rhs)
316+
}
317+
#endif
318+
}
319+
256320
/// Compare two values using `==` or `!=`.
257321
///
258322
/// - Parameters:
259323
/// - lhs: The left-hand operand.
260324
/// - lhsID: A value that uniquely identifies the expression represented by
261325
/// `lhs` in the context of the expectation currently being evaluated.
262-
/// - rhs: The left-hand operand.
326+
/// - rhs: The right-hand operand.
263327
/// - rhsID: A value that uniquely identifies the expression represented by
264328
/// `rhs` in the context of the expectation currently being evaluated.
265329
/// - op: A function that performs an operation on `lhs` and `rhs`.
@@ -283,97 +347,10 @@ extension __ExpectationContext {
283347
) rethrows -> Bool {
284348
let lhs = copy lhs
285349
let rhs = copy rhs
286-
return try captureValue(op(captureValue(lhs, lhsID), captureValue(rhs, rhsID)), opID)
287-
}
288-
289-
/// Compare two bidirectional collections using `==` or `!=`.
290-
///
291-
/// This overload of `__cmp()` performs a diffing operation on `lhs` and `rhs`
292-
/// if the result of `op(lhs, rhs)` is `false`.
293-
///
294-
/// - Warning: This function is used to implement the `#expect()` and
295-
/// `#require()` macros. Do not call it directly.
296-
public func __cmp<C>(
297-
_ op: (C, C) -> Bool,
298-
_ opID: __ExpressionID,
299-
_ lhs: borrowing C,
300-
_ lhsID: __ExpressionID,
301-
_ rhs: borrowing C,
302-
_ rhsID: __ExpressionID
303-
) -> Bool where C: BidirectionalCollection, C.Element: Equatable {
304-
let lhs = copy lhs
305-
let rhs = copy rhs
306-
let result = captureValue(op(captureValue(lhs, lhsID), captureValue(rhs, rhsID)), opID)
307-
308-
if !result {
309-
differences[opID] = { CollectionDifference<Any>(lhs.difference(from: rhs)) }
310-
}
311-
312-
return result
313-
}
314-
315-
/// Compare two range expressions using `==` or `!=`.
316-
///
317-
/// This overload of `__cmp()` does _not_ perform a diffing operation on `lhs`
318-
/// and `rhs`. Range expressions are not usefully diffable the way other kinds
319-
/// of collections are. ([#639](https://github.com/swiftlang/swift-testing/issues/639))
320-
///
321-
/// - Warning: This function is used to implement the `#expect()` and
322-
/// `#require()` macros. Do not call it directly.
323-
@inlinable public func __cmp<R>(
324-
_ op: (R, R) -> Bool,
325-
_ opID: __ExpressionID,
326-
_ lhs: borrowing R,
327-
_ lhsID: __ExpressionID,
328-
_ rhs: borrowing R,
329-
_ rhsID: __ExpressionID
330-
) -> Bool where R: RangeExpression & BidirectionalCollection, R.Element: Equatable {
331-
let lhs = copy lhs
332-
let rhs = copy rhs
333-
return captureValue(op(captureValue(lhs, lhsID), captureValue(rhs, rhsID)), opID)
334-
}
335-
336-
/// Compare two strings using `==` or `!=`.
337-
///
338-
/// This overload of `__cmp()` performs a diffing operation on `lhs` and `rhs`
339-
/// if the result of `op(lhs, rhs)` is `false`, but does so by _line_, not by
340-
/// _character_.
341-
///
342-
/// - Warning: This function is used to implement the `#expect()` and
343-
/// `#require()` macros. Do not call it directly.
344-
public func __cmp<S>(
345-
_ op: (S, S) -> Bool,
346-
_ opID: __ExpressionID,
347-
_ lhs: borrowing S,
348-
_ lhsID: __ExpressionID,
349-
_ rhs: borrowing S,
350-
_ rhsID: __ExpressionID
351-
) -> Bool where S: StringProtocol {
352-
let lhs = copy lhs
353-
let rhs = copy rhs
354-
let result = captureValue(op(captureValue(lhs, lhsID), captureValue(rhs, rhsID)), opID)
350+
let result = try captureValue(op(captureValue(lhs, lhsID), captureValue(rhs, rhsID)), opID)
355351

356352
if !result {
357-
differences[opID] = {
358-
// Compare strings by line, not by character.
359-
let lhsLines = String(lhs).split(whereSeparator: \.isNewline)
360-
let rhsLines = String(rhs).split(whereSeparator: \.isNewline)
361-
362-
if lhsLines.count == 1 && rhsLines.count == 1 {
363-
// There are no newlines in either string, so there's no meaningful
364-
// per-line difference. Bail.
365-
return nil
366-
}
367-
368-
let diff = lhsLines.difference(from: rhsLines)
369-
if diff.isEmpty {
370-
// The strings must have compared on a per-character basis, or this
371-
// operator doesn't behave the way we expected. Bail.
372-
return nil
373-
}
374-
375-
return CollectionDifference<Any>(diff)
376-
}
353+
captureDifferences(lhs, rhs, opID)
377354
}
378355

379356
return result

0 commit comments

Comments
 (0)