11#include " machine.hpp"
22
3+ #include < cstring>
34#include < fcntl.h>
45#include < linux/kvm.h>
56#include < stdexcept>
7+ #include < span>
68#include < sys/mman.h>
79#include < sys/stat.h>
810#include < unistd.h>
911#ifdef TINYKVM_ARCH_AMD64
1012#include " amd64/amd64.hpp"
13+ #include " amd64/paging.hpp"
1114#endif
1215#include " linux/fds.hpp"
1316#include " linux/threads.hpp"
1417
1518namespace tinykvm {
1619
20+ struct ColdStartAccessedRange {
21+ uint64_t start;
22+ uint64_t end;
23+ };
24+
1725struct ColdStartThreadState {
1826 int tid;
1927 tinykvm_x86regs regs;
@@ -73,10 +81,13 @@ struct SnapshotState {
7381 // MMapCache m_mmap_cache;
7482 Machine::address_t mmap_current;
7583
76- bool main_memory_writes;
7784 Machine::address_t m_page_tables;
85+ bool main_memory_writes;
86+ uint32_t num_access_ranges;
7887
79- static constexpr size_t Size () noexcept { return 4096ul ; }
88+ char current[0 ];
89+
90+ static constexpr size_t Size () noexcept { return vMemory::ColdStartStateSize (); }
8091
8192 template <typename T>
8293 T* next (void *& current) {
@@ -106,6 +117,7 @@ bool Machine::load_snapshot_state()
106117 throw std::runtime_error (" No valid snapshot state found" );
107118 }
108119 if (state.size < sizeof (SnapshotState) || state.size > SnapshotState::Size ()) {
120+ fprintf (stderr, " Invalid snapshot state size: %u\n " , state.size );
109121 throw std::runtime_error (" Invalid snapshot state size" );
110122 }
111123
@@ -129,7 +141,22 @@ bool Machine::load_snapshot_state()
129141 this ->memory .main_memory_writes = state.main_memory_writes ;
130142 this ->memory .page_tables = state.m_page_tables ;
131143
132- void * current = reinterpret_cast <char *>(&state) + sizeof (SnapshotState);
144+ void * current = state.current ;
145+ // Load populate pages
146+ for (unsigned i = 0 ; i < state.num_access_ranges ; i++) {
147+ ColdStartAccessedRange* range = state.next <ColdStartAccessedRange>(current);
148+ if (range->start >= MemoryBanks::ARENA_BASE_ADDRESS)
149+ continue ;
150+ try {
151+ printf (" Populating pages from 0x%lX -> 0x%lX\n " , range->start , range->end );
152+ char * page = this ->memory .get_userpage_at (range->start );
153+ madvise (page, range->end - range->start , MADV_WILLNEED);
154+ } catch (const std::exception& e) {
155+ fprintf (stderr, " Failed to access page at 0x%lX: %s\n " , range->start , e.what ());
156+ continue ;
157+ }
158+ }
159+
133160 // Load the thread states
134161 ColdStartThreads* threads = state.next <ColdStartThreads>(current);
135162 if (threads->count > 0 ) {
@@ -181,7 +208,7 @@ bool Machine::load_snapshot_state()
181208 }
182209 return true ;
183210}
184- void Machine::save_snapshot_state_now () const
211+ void Machine::save_snapshot_state_now (const std::vector< uint64_t >& populate_pages ) const
185212{
186213 if (this ->is_forked ()) {
187214 throw std::runtime_error (" Cannot save snapshot state of a forked VM" );
@@ -209,7 +236,40 @@ void Machine::save_snapshot_state_now() const
209236 state.main_memory_writes = this ->memory .main_memory_writes ;
210237 state.m_page_tables = this ->memory .page_tables ;
211238
212- void * current = reinterpret_cast <char *>(&state) + sizeof (SnapshotState);
239+ void * current = state.current ;
240+ // Save populate pages
241+ state.num_access_ranges = 0 ;
242+ if (!populate_pages.empty ()) {
243+ uint64_t current_begin = 0 ;
244+ uint64_t current_end = 0 ;
245+ for (uint64_t page_addr : populate_pages) {
246+ if (page_addr >= MemoryBanks::ARENA_BASE_ADDRESS)
247+ continue ;
248+ // Merge contiguous ranges
249+ if (current_end == page_addr) {
250+ current_end += vMemory::PageSize ();
251+ continue ;
252+ }
253+ // Store previous range
254+ if (current_end != current_begin) {
255+ ColdStartAccessedRange* range = state.next <ColdStartAccessedRange>(current);
256+ range->start = current_begin;
257+ range->end = current_end;
258+ state.num_access_ranges ++;
259+ }
260+ // Start new range
261+ current_begin = page_addr;
262+ current_end = page_addr;
263+ }
264+ // Store last range
265+ if (current_end != current_begin) {
266+ ColdStartAccessedRange* range = state.next <ColdStartAccessedRange>(current);
267+ range->start = current_begin;
268+ range->end = current_end;
269+ state.num_access_ranges ++;
270+ }
271+ }
272+
213273 // Save the multi-threading state
214274 ColdStartThreads* threads = state.next <ColdStartThreads>(current);
215275 if (this ->has_threads ()) {
@@ -274,7 +334,14 @@ void Machine::save_snapshot_state_now() const
274334 }
275335
276336 } catch (const MachineException& me) {
337+ fprintf (stderr, " Failed to get snapshot state: %s Data: 0x%#lX\n " ,
338+ me.what (), me.data ());
339+ state.magic = 0 ; // Invalidate
277340 throw std::runtime_error (std::string (" Failed to get snapshot state: " ) + me.what ());
341+ } catch (const std::exception& e) {
342+ fprintf (stderr, " Failed to get snapshot state: %s\n " , e.what ());
343+ state.magic = 0 ; // Invalidate
344+ throw ;
278345 }
279346}
280347
0 commit comments