Skip to content

Commit e45efd7

Browse files
authored
Refactor type enumeration loop. (#388)
1 parent db5a646 commit e45efd7

File tree

4 files changed

+67
-75
lines changed

4 files changed

+67
-75
lines changed

Sources/Testing/ExitTests/ExitTest.swift

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -79,33 +79,20 @@ extension ExitTest {
7979
/// - Returns: The specified exit test function, or `nil` if no such exit test
8080
/// could be found.
8181
public static func find(at sourceLocation: SourceLocation) -> Self? {
82-
struct Context {
83-
var sourceLocation: SourceLocation
84-
var result: ExitTest?
85-
}
86-
var context = Context(sourceLocation: sourceLocation)
87-
withUnsafeMutablePointer(to: &context) { context in
88-
swt_enumerateTypes(context) { type, context in
89-
let context = context!.assumingMemoryBound(to: (Context).self)
90-
if let type = unsafeBitCast(type, to: Any.Type.self) as? any __ExitTestContainer.Type,
91-
type.__sourceLocation == context.pointee.sourceLocation {
92-
context.pointee.result = ExitTest(
93-
expectedExitCondition: type.__expectedExitCondition,
94-
body: type.__body,
95-
sourceLocation: type.__sourceLocation
96-
)
97-
return false
98-
}
99-
return true
100-
} withNamesMatching: { typeName, _ in
101-
// strstr() lets us avoid copying either string before comparing.
102-
Self._exitTestContainerTypeNameMagic.withCString { testContainerTypeNameMagic in
103-
nil != strstr(typeName, testContainerTypeNameMagic)
104-
}
82+
var result: Self?
83+
84+
enumerateTypes(withNamesContaining: _exitTestContainerTypeNameMagic) { type, stop in
85+
if let type = type as? any __ExitTestContainer.Type, type.__sourceLocation == sourceLocation {
86+
result = ExitTest(
87+
expectedExitCondition: type.__expectedExitCondition,
88+
body: type.__body,
89+
sourceLocation: type.__sourceLocation
90+
)
91+
stop = true
10592
}
10693
}
10794

108-
return context.result
95+
return result
10996
}
11097
}
11198

Sources/Testing/Test+Discovery.swift

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,12 @@ extension Test {
4848
private static var _all: some Sequence<Self> {
4949
get async {
5050
await withTaskGroup(of: [Self].self) { taskGroup in
51-
swt_enumerateTypes(&taskGroup) { type, context in
52-
if let type = unsafeBitCast(type, to: Any.Type.self) as? any __TestContainer.Type {
53-
let taskGroup = context!.assumingMemoryBound(to: TaskGroup<[Self]>.self)
54-
taskGroup.pointee.addTask {
51+
enumerateTypes(withNamesContaining: _testContainerTypeNameMagic) { type, _ in
52+
if let type = type as? any __TestContainer.Type {
53+
taskGroup.addTask {
5554
await type.__tests
5655
}
5756
}
58-
return true
59-
} withNamesMatching: { typeName, _ in
60-
// strstr() lets us avoid copying either string before comparing.
61-
Self._testContainerTypeNameMagic.withCString { testContainerTypeNameMagic in
62-
nil != strstr(typeName, testContainerTypeNameMagic)
63-
}
6457
}
6558

6659
return await taskGroup.reduce(into: [], +=)
@@ -115,3 +108,34 @@ extension Test {
115108
return tests.count - originalCount
116109
}
117110
}
111+
112+
// MARK: -
113+
114+
/// The type of callback called by ``enumerateTypes(withNamesContaining:_:)``.
115+
///
116+
/// - Parameters:
117+
/// - type: A Swift type.
118+
/// - stop: An `inout` boolean variable indicating whether type enumeration
119+
/// should stop after the function returns. Set `stop` to `true` to stop
120+
/// type enumeration.
121+
typealias TypeEnumerator = (_ type: Any.Type, _ stop: inout Bool) -> Void
122+
123+
/// Enumerate all types known to Swift found in the current process whose names
124+
/// contain a given substring.
125+
///
126+
/// - Parameters:
127+
/// - nameSubstring: A string which the names of matching classes all contain.
128+
/// - body: A function to invoke, once per matching type.
129+
func enumerateTypes(withNamesContaining nameSubstring: String, _ typeEnumerator: TypeEnumerator) {
130+
withoutActuallyEscaping(typeEnumerator) { typeEnumerator in
131+
withUnsafePointer(to: typeEnumerator) { context in
132+
swt_enumerateTypes(withNamesContaining: nameSubstring, .init(mutating: context)) { type, stop, context in
133+
let typeEnumerator = context!.load(as: TypeEnumerator.self)
134+
let type = unsafeBitCast(type, to: Any.Type.self)
135+
var stop2 = false
136+
typeEnumerator(type, &stop2)
137+
stop.pointee = stop2
138+
}
139+
}
140+
}
141+
}

Sources/TestingInternals/Discovery.cpp

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "Discovery.h"
1212

1313
#include <atomic>
14+
#include <cstring>
1415
#include <iterator>
1516
#include <type_traits>
1617
#include <vector>
@@ -328,13 +329,13 @@ static void enumerateTypeMetadataSections(const SectionEnumerator& body) {}
328329

329330
#pragma mark -
330331

331-
void swt_enumerateTypes(void *context, SWTTypeEnumerator body, SWTTypeNameFilter nameFilter) {
332+
void swt_enumerateTypesWithNamesContaining(const char *nameSubstring, void *context, SWTTypeEnumerator body) {
332333
enumerateTypeMetadataSections([=] (const void *section, size_t size) {
333334
auto records = reinterpret_cast<const SWTTypeMetadataRecord *>(section);
334335
size_t recordCount = size / sizeof(SWTTypeMetadataRecord);
335336

336-
bool keepGoing = true;
337-
for (size_t i = 0; i < recordCount && keepGoing; i++) {
337+
bool stop = false;
338+
for (size_t i = 0; i < recordCount && !stop; i++) {
338339
const auto& record = records[i];
339340

340341
auto contextDescriptor = record.getContextDescriptor();
@@ -348,19 +349,16 @@ void swt_enumerateTypes(void *context, SWTTypeEnumerator body, SWTTypeNameFilter
348349
continue;
349350
}
350351

351-
// If the caller supplied a name filtering function, check that the type's
352-
// name passes. This will be more expensive than the checks above, but
353-
// should be cheaper than realizing the metadata.
354-
if (nameFilter) {
355-
const char *typeName = contextDescriptor->getName();
356-
bool nameOK = typeName && (* nameFilter)(typeName, context);
357-
if (!nameOK) {
358-
continue;
359-
}
352+
// Check that the type's name passes. This will be more expensive than the
353+
// checks above, but should be cheaper than realizing the metadata.
354+
const char *typeName = contextDescriptor->getName();
355+
bool nameOK = typeName && nullptr != std::strstr(typeName, nameSubstring);
356+
if (!nameOK) {
357+
continue;
360358
}
361359

362360
if (void *typeMetadata = contextDescriptor->getMetadata()) {
363-
keepGoing = body(typeMetadata, context);
361+
body(typeMetadata, &stop, context);
364362
}
365363
}
366364
});

Sources/TestingInternals/include/Discovery.h

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,45 +17,28 @@
1717

1818
SWT_ASSUME_NONNULL_BEGIN
1919

20-
/// The type of callback that is called by `swt_enumerateTypes()`.
20+
/// The type of callback called by `swt_enumerateTypes()`.
2121
///
2222
/// - Parameters:
2323
/// - typeMetadata: A type metadata pointer that can be bitcast to `Any.Type`.
24+
/// - stop: A pointer to a boolean variable indicating whether type
25+
/// enumeration should stop after the function returns. Set `*stop` to
26+
/// `true` to stop type enumeration.
2427
/// - context: An arbitrary pointer passed by the caller to
2528
/// `swt_enumerateTypes()`.
26-
///
27-
/// - Returns: Whether or not to continue enumeration.
28-
typedef bool (* SWTTypeEnumerator)(void *typeMetadata, void *_Null_unspecified context);
29-
30-
/// The type name filter that is called by `swt_enumerateTypes()`.
31-
///
32-
/// - Parameters:
33-
/// - typeName: The name of the type being considered, as a C string.
34-
/// - context: An arbitrary pointer passed by the caller to
35-
/// `swt_enumerateTypes()`.
36-
///
37-
/// - Returns: Whether or not the type named by `typeName` should be passed to
38-
/// the corresponding enumerator function.
39-
typedef bool (* SWTTypeNameFilter)(const char *typeName, void *_Null_unspecified context);
29+
typedef void (* SWTTypeEnumerator)(void *typeMetadata, bool *stop, void *_Null_unspecified context);
4030

4131
/// Enumerate all types known to Swift found in the current process.
4232
///
4333
/// - Parameters:
44-
/// - nameFilter: If not `nullptr`, a filtering function that checks if a type
45-
/// name is valid before realizing the type.
46-
/// - body: A function to invoke. `context` is passed to it along with a
47-
/// type metadata pointer (which can be bitcast to `Any.Type`.)
34+
/// - nameSubstring: A string which the names of matching classes all contain.
4835
/// - context: An arbitrary pointer to pass to `body`.
49-
///
50-
/// This function may enumerate the same type more than once (for instance, if
51-
/// it is present in an image's metadata table multiple times, or if it is an
52-
/// Objective-C class implemented in Swift.) Callers are responsible for
53-
/// deduping type metadata pointers passed to `body`.
54-
SWT_EXTERN void swt_enumerateTypes(
36+
/// - body: A function to invoke, once per matching type.
37+
SWT_EXTERN void swt_enumerateTypesWithNamesContaining(
38+
const char *nameSubstring,
5539
void *_Null_unspecified context,
56-
SWTTypeEnumerator body,
57-
SWTTypeNameFilter _Nullable nameFilter
58-
) SWT_SWIFT_NAME(swt_enumerateTypes(_:_:withNamesMatching:));
40+
SWTTypeEnumerator body
41+
) SWT_SWIFT_NAME(swt_enumerateTypes(withNamesContaining:_:_:));
5942

6043
SWT_ASSUME_NONNULL_END
6144

0 commit comments

Comments
 (0)