Skip to content

Commit 65b7ef2

Browse files
authored
Use the Toolhelp32 API to enumerate loaded Win32 modules. (swiftlang#892)
This PR replaces our call to `EnumProcessModules()` with one to `CreateToolhelp32Snapshot(TH32CS_SNAPMODULE)`. Three reasons: 1. `EnumProcessModules()` requires us to specify a large, fixed-size buffer to contain all the `HMODULE` handles; 2. `EnumProcessModules()` does not own any references to the handles it returns, meaning that a module can be unloaded while we are iterating over them (while `CreateToolhelp32Snapshot()` temporarily bumps the refcounts of the handles it produces); and 3. `CreateToolhelp32Snapshot()` lets us produce a lazy sequence of `HMODULE` values rather than an array, letting us write somewhat Swiftier code that uses it. The overhead of using `CreateToolhelp32Snapshot()` was negligible (below the noise level when measuring). ### Checklist: - [x] Code and documentation should follow the style of the [Style Guide](https://github.com/apple/swift-testing/blob/main/Documentation/StyleGuide.md). - [x] If public symbols are renamed or modified, DocC references should be updated.
1 parent 4ca12ee commit 65b7ef2

File tree

2 files changed

+56
-19
lines changed

2 files changed

+56
-19
lines changed

Sources/Testing/Support/Additions/WinSDKAdditions.swift

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,56 @@ let STATUS_SIGNAL_CAUGHT_BITS = {
5050

5151
return result
5252
}()
53+
54+
// MARK: - HMODULE members
55+
56+
extension HMODULE {
57+
/// A helper type that manages state for ``HMODULE/all``.
58+
private final class _AllState {
59+
/// The toolhelp snapshot.
60+
var snapshot: HANDLE?
61+
62+
/// The module iterator.
63+
var me = MODULEENTRY32W()
64+
65+
deinit {
66+
if let snapshot {
67+
CloseHandle(snapshot)
68+
}
69+
}
70+
}
71+
72+
/// All modules loaded in the current process.
73+
///
74+
/// - Warning: It is possible for one or more modules in this sequence to be
75+
/// unloaded while you are iterating over it. To minimize the risk, do not
76+
/// discard the sequence until iteration is complete. Modules containing
77+
/// Swift code can never be safely unloaded.
78+
static var all: some Sequence<Self> {
79+
sequence(state: _AllState()) { state in
80+
if let snapshot = state.snapshot {
81+
// We have already iterated over the first module. Return the next one.
82+
if Module32NextW(snapshot, &state.me) {
83+
return state.me.hModule
84+
}
85+
} else {
86+
// Create a toolhelp snapshot that lists modules.
87+
guard let snapshot = CreateToolhelp32Snapshot(DWORD(TH32CS_SNAPMODULE), 0) else {
88+
return nil
89+
}
90+
state.snapshot = snapshot
91+
92+
// Initialize the iterator for use by the resulting sequence and return
93+
// the first module.
94+
state.me.dwSize = DWORD(MemoryLayout.stride(ofValue: state.me))
95+
if Module32FirstW(snapshot, &state.me) {
96+
return state.me.hModule
97+
}
98+
}
99+
100+
// Reached the end of the iteration.
101+
return nil
102+
}
103+
}
104+
}
53105
#endif

Sources/Testing/Support/GetSymbol.swift

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -70,25 +70,10 @@ func symbol(in handle: ImageAddress? = nil, named symbolName: String) -> UnsafeR
7070
}
7171
}
7272

73-
// Find all the modules loaded in the current process. We assume there
74-
// aren't more than 1024 loaded modules (as does Microsoft sample code.)
75-
return withUnsafeTemporaryAllocation(of: HMODULE?.self, capacity: 1024) { hModules in
76-
let byteCount = DWORD(hModules.count * MemoryLayout<HMODULE?>.stride)
77-
var byteCountNeeded: DWORD = 0
78-
guard K32EnumProcessModules(GetCurrentProcess(), hModules.baseAddress!, byteCount, &byteCountNeeded) else {
79-
return nil
80-
}
81-
82-
// Enumerate all modules looking for one containing the given symbol.
83-
let hModuleCount = min(hModules.count, Int(byteCountNeeded) / MemoryLayout<HMODULE?>.stride)
84-
let hModulesEnd = hModules.index(hModules.startIndex, offsetBy: hModuleCount)
85-
for hModule in hModules[..<hModulesEnd] {
86-
if let hModule, let result = GetProcAddress(hModule, symbolName) {
87-
return unsafeBitCast(result, to: UnsafeRawPointer.self)
88-
}
89-
}
90-
return nil
91-
}
73+
return HMODULE.all.lazy
74+
.compactMap { GetProcAddress($0, symbolName) }
75+
.map { unsafeBitCast($0, to: UnsafeRawPointer.self) }
76+
.first
9277
}
9378
#else
9479
#warning("Platform-specific implementation missing: Dynamic loading unavailable")

0 commit comments

Comments
 (0)