Skip to content

Commit 9ca4500

Browse files
feat: software breakpoint hook
1 parent 58aab31 commit 9ca4500

File tree

3 files changed

+200
-29
lines changed

3 files changed

+200
-29
lines changed

include/blook/hook.h

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
#include "utils.h"
77
#include "zasm/x86/assembler.hpp"
88
#include <optional>
9+
#include <map>
10+
#include <Windows.h>
911

1012
namespace blook {
1113

@@ -96,12 +98,10 @@ class VEHHookManager {
9698

9799
struct SoftwareBreakpoint {
98100
void *address = nullptr;
99-
std::vector<uint8_t> original_bytes;
100101
};
101102

102103
struct PagefaultBreakpoint {
103104
void *address = nullptr;
104-
int32_t origin_protection = 0;
105105
};
106106

107107
struct HardwareBreakpointInformation {
@@ -110,6 +110,18 @@ class VEHHookManager {
110110
Trampoline trampoline;
111111
};
112112

113+
struct SoftwareBreakpointInformation {
114+
SoftwareBreakpoint bp;
115+
BreakpointCallback callback;
116+
std::vector<uint8_t> original_bytes;
117+
};
118+
119+
struct PagefaultBreakpointInformation {
120+
PagefaultBreakpoint bp;
121+
BreakpointCallback callback;
122+
int32_t origin_protection = 0;
123+
};
124+
113125
static VEHHookManager &instance() {
114126
static VEHHookManager instance;
115127
return instance;
@@ -119,8 +131,16 @@ class VEHHookManager {
119131
short dr_index = -1;
120132
};
121133

134+
struct SoftwareBreakpointHandler {
135+
void *address = nullptr;
136+
};
137+
138+
struct PagefaultBreakpointHandler {
139+
void *address = nullptr;
140+
};
141+
122142
using VEHHookHandler =
123-
std::variant<std::monostate, HardwareBreakpointHandler>;
143+
std::variant<std::monostate, HardwareBreakpointHandler, SoftwareBreakpointHandler, PagefaultBreakpointHandler>;
124144

125145
VEHHookHandler add_breakpoint(HardwareBreakpoint bp,
126146
BreakpointCallback callback);
@@ -131,6 +151,10 @@ class VEHHookManager {
131151
void remove_breakpoint(const VEHHookHandler &handler);
132152

133153
std::array<std::optional<HardwareBreakpointInformation>, 4> hw_breakpoints;
154+
std::map<void *, SoftwareBreakpointInformation> sw_breakpoints;
155+
std::map<void *, PagefaultBreakpointInformation> pf_breakpoints;
156+
std::map<DWORD, void *> thread_bp_in_progress;
157+
134158
void sync_hw_breakpoints();
135159
private:
136160

src/hook.cpp

Lines changed: 127 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -121,34 +121,77 @@ Trampoline Trampoline::make(Pointer pCode, size_t minByteSize,
121121
}
122122

123123
static LONG blook_VectoredExceptionHandler(_EXCEPTION_POINTERS *ExceptionInfo) {
124-
auto &manager = blook::VEHHookManager::instance();
125-
auto code = ExceptionInfo->ExceptionRecord->ExceptionCode;
126-
auto address = ExceptionInfo->ExceptionRecord->ExceptionAddress;
127-
if (code == EXCEPTION_SINGLE_STEP) {
128-
for (const auto &bp : manager.hw_breakpoints) {
129-
if (bp.has_value() && bp->bp.address == address) {
130-
if (bp->callback) {
131-
size_t origRip = ExceptionInfo->ContextRecord->Rip;
132-
VEHHookManager::VEHHookContext ctx(ExceptionInfo);
133-
bp->callback(ctx);
134-
if (ExceptionInfo->ContextRecord->Rip == origRip) {
135-
// If the callback didn't change the RIP, jump to trampoline
136-
ExceptionInfo->ContextRecord->Rip =
137-
(size_t)bp->trampoline.pTrampoline.data();
138-
}
124+
auto &manager = blook::VEHHookManager::instance();
125+
auto code = ExceptionInfo->ExceptionRecord->ExceptionCode;
126+
auto address = ExceptionInfo->ExceptionRecord->ExceptionAddress;
127+
auto thread_id = GetCurrentThreadId();
128+
129+
if (code == EXCEPTION_SINGLE_STEP) {
130+
// Check if this is a re-enable step for a SW/Page-Guard BP
131+
if (manager.thread_bp_in_progress.contains(thread_id)) {
132+
void *bp_address = manager.thread_bp_in_progress.at(thread_id);
133+
134+
if (manager.sw_breakpoints.contains(bp_address)) {
135+
uint8_t int3 = 0xCC;
136+
Pointer(bp_address).write(nullptr, std::span(&int3, 1));
137+
} else if (manager.pf_breakpoints.contains(bp_address)) {
138+
auto &bp = manager.pf_breakpoints.at(bp_address);
139+
DWORD old_protect;
140+
VirtualProtect(bp_address, 1, bp.origin_protection | PAGE_GUARD, &old_protect);
141+
}
142+
manager.thread_bp_in_progress.erase(thread_id);
143+
return EXCEPTION_CONTINUE_EXECUTION;
139144
}
140-
}
145+
146+
// If not a re-enable step, check for a hardware breakpoint
147+
for (auto &bp_opt : manager.hw_breakpoints) {
148+
if (bp_opt.has_value() && bp_opt->bp.address == address) {
149+
if (bp_opt->callback) {
150+
size_t origRip = ExceptionInfo->ContextRecord->Rip;
151+
VEHHookManager::VEHHookContext ctx(ExceptionInfo);
152+
bp_opt->callback(ctx);
153+
if (ExceptionInfo->ContextRecord->Rip == origRip) {
154+
ExceptionInfo->ContextRecord->Rip =
155+
(size_t)bp_opt->trampoline.pTrampoline.data();
156+
}
157+
}
158+
return EXCEPTION_CONTINUE_EXECUTION;
159+
}
160+
}
161+
162+
} else if (code == EXCEPTION_BREAKPOINT) {
163+
if (manager.sw_breakpoints.contains(address)) {
164+
auto &bp = manager.sw_breakpoints.at(address);
165+
Pointer(address).write(nullptr, bp.original_bytes);
166+
ExceptionInfo->ContextRecord->Rip = (DWORD_PTR)address;
167+
168+
if (bp.callback) {
169+
VEHHookManager::VEHHookContext ctx(ExceptionInfo);
170+
bp.callback(ctx);
171+
}
172+
173+
ExceptionInfo->ContextRecord->EFlags |= (1 << 8);
174+
manager.thread_bp_in_progress[thread_id] = address;
175+
return EXCEPTION_CONTINUE_EXECUTION;
176+
}
177+
} else if (code == EXCEPTION_GUARD_PAGE) {
178+
if (manager.pf_breakpoints.contains(address)) {
179+
auto &bp = manager.pf_breakpoints.at(address);
180+
if (bp.callback) {
181+
VEHHookManager::VEHHookContext ctx(ExceptionInfo);
182+
bp.callback(ctx);
183+
}
184+
manager.thread_bp_in_progress[thread_id] = address;
185+
ExceptionInfo->ContextRecord->EFlags |= (1 << 8);
186+
return EXCEPTION_CONTINUE_EXECUTION;
187+
}
188+
return EXCEPTION_CONTINUE_SEARCH;
141189
}
142-
return EXCEPTION_CONTINUE_EXECUTION;
143-
} else if (code == EXCEPTION_BREAKPOINT) {
144-
// Handle software breakpoints here if needed
145-
} else if (code == EXCEPTION_ACCESS_VIOLATION) {
146-
// Handle pagefault breakpoints here if needed
147-
}
148190

149-
return EXCEPTION_CONTINUE_SEARCH;
191+
return EXCEPTION_CONTINUE_SEARCH;
150192
}
151193

194+
152195
void ensureVectoredExceptionHandler() {
153196
static bool handler_installed = false;
154197
if (!handler_installed) {
@@ -191,12 +234,54 @@ VEHHookManager::add_breakpoint(HardwareBreakpoint bp,
191234
VEHHookManager::VEHHookHandler
192235
VEHHookManager::add_breakpoint(SoftwareBreakpoint bp,
193236
BreakpointCallback callback) {
194-
throw std::runtime_error("Software breakpoints are not supported yet.");
237+
ensureVectoredExceptionHandler();
238+
if (sw_breakpoints.contains(bp.address)) {
239+
throw std::runtime_error("Software breakpoint already exists at this address.");
240+
}
241+
242+
SoftwareBreakpointInformation info;
243+
info.bp = bp;
244+
info.callback = callback;
245+
246+
auto original_byte = Pointer(bp.address).try_read<uint8_t>();
247+
if (!original_byte) {
248+
throw std::runtime_error("Failed to read original byte for software breakpoint.");
249+
}
250+
info.original_bytes.push_back(*original_byte);
251+
252+
uint8_t int3 = 0xCC;
253+
if (!Pointer(bp.address).write(nullptr, std::span(&int3, 1))) {
254+
throw std::runtime_error("Failed to write INT3 for software breakpoint.");
255+
}
256+
257+
sw_breakpoints[bp.address] = std::move(info);
258+
return SoftwareBreakpointHandler{bp.address};
195259
}
196260
VEHHookManager::VEHHookHandler
197261
VEHHookManager::add_breakpoint(PagefaultBreakpoint bp,
198262
BreakpointCallback callback) {
199-
throw std::runtime_error("Pagefault breakpoints are not supported yet.");
263+
ensureVectoredExceptionHandler();
264+
if (pf_breakpoints.contains(bp.address)) {
265+
throw std::runtime_error("Pagefault breakpoint already exists at this address.");
266+
}
267+
268+
PagefaultBreakpointInformation info;
269+
info.bp = bp;
270+
info.callback = callback;
271+
272+
MEMORY_BASIC_INFORMATION mbi;
273+
if (VirtualQuery(bp.address, &mbi, sizeof(mbi)) == 0) {
274+
throw std::runtime_error("VirtualQuery failed for pagefault breakpoint.");
275+
}
276+
info.origin_protection = mbi.Protect;
277+
278+
DWORD old_protect;
279+
if (!VirtualProtect(bp.address, 1, info.origin_protection | PAGE_GUARD, &old_protect)) {
280+
throw std::runtime_error("VirtualProtect failed for pagefault breakpoint.");
281+
}
282+
283+
pf_breakpoints[bp.address] = std::move(info);
284+
return PagefaultBreakpointHandler{bp.address};
200285
}
201286
void VEHHookManager::remove_breakpoint(const VEHHookHandler &handler) {
202287
if (auto _hwbp = std::get_if<HardwareBreakpointHandler>(&handler)) {
@@ -206,7 +291,23 @@ void VEHHookManager::remove_breakpoint(const VEHHookHandler &handler) {
206291
}
207292
hw_breakpoints[_hwbp->dr_index].reset();
208293
sync_hw_breakpoints();
209-
} else {
294+
} else if (auto _swbp = std::get_if<SoftwareBreakpointHandler>(&handler)) {
295+
if (!sw_breakpoints.contains(_swbp->address)) {
296+
return;
297+
}
298+
auto& bp_info = sw_breakpoints.at(_swbp->address);
299+
Pointer(_swbp->address).write(nullptr, bp_info.original_bytes);
300+
sw_breakpoints.erase(_swbp->address);
301+
} else if (auto _pfbp = std::get_if<PagefaultBreakpointHandler>(&handler)) {
302+
if (!pf_breakpoints.contains(_pfbp->address)) {
303+
return;
304+
}
305+
auto& bp_info = pf_breakpoints.at(_pfbp->address);
306+
DWORD old_protect;
307+
VirtualProtect(_pfbp->address, 1, bp_info.origin_protection, &old_protect);
308+
pf_breakpoints.erase(_pfbp->address);
309+
}
310+
else {
210311
throw std::runtime_error("Unsupported breakpoint type.");
211312
}
212313
}

src/tests/test_windows.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,52 @@ TEST(VEHHookTest, HookMultipleFunctions) {
377377
ASSERT_FALSE(called4);
378378
}
379379

380+
TEST(VEHHookTest, SoftwareBreakpoint) {
381+
bool called = false;
382+
auto handler = blook::VEHHookManager::instance().add_breakpoint(
383+
blook::VEHHookManager::SoftwareBreakpoint{
384+
.address = (void *)veh_test_target_func},
385+
[&](blook::VEHHookManager::VEHHookContext &ctx) { called = true; });
386+
387+
ASSERT_TRUE(
388+
std::holds_alternative<blook::VEHHookManager::SoftwareBreakpointHandler>(
389+
handler));
390+
391+
veh_test_target_func(3, 5);
392+
393+
ASSERT_TRUE(called);
394+
395+
blook::VEHHookManager::instance().remove_breakpoint(handler);
396+
397+
called = false;
398+
auto result = veh_test_target_func(2, 4);
399+
EXPECT_EQ(result, 2 * 4 + 42);
400+
ASSERT_FALSE(called);
401+
}
402+
403+
// TEST(VEHHookTest, PagefaultBreakpoint) {
404+
// bool called = false;
405+
// auto handler = blook::VEHHookManager::instance().add_breakpoint(
406+
// blook::VEHHookManager::PagefaultBreakpoint{
407+
// .address = (void *)veh_test_target_func},
408+
// [&](blook::VEHHookManager::VEHHookContext &ctx) { called = true; });
409+
410+
// ASSERT_TRUE(
411+
// std::holds_alternative<blook::VEHHookManager::PagefaultBreakpointHandler>(
412+
// handler));
413+
414+
// veh_test_target_func(3, 5);
415+
416+
// ASSERT_TRUE(called);
417+
418+
// blook::VEHHookManager::instance().remove_breakpoint(handler);
419+
420+
// called = false;
421+
// auto result = veh_test_target_func(2, 4);
422+
// EXPECT_EQ(result, 2 * 4 + 42);
423+
// ASSERT_FALSE(called);
424+
// }
425+
380426
int main(int argc, char **argv) {
381427
::testing::InitGoogleTest(&argc, argv);
382428
return RUN_ALL_TESTS();

0 commit comments

Comments
 (0)