@@ -50,10 +50,10 @@ struct ColdStartSocketPair {
5050 int type;
5151};
5252
53- struct ColdStartState {
53+ struct SnapshotState {
5454 static constexpr uint32_t MAGIC = 0x564D4353 ; // 'VMCS'
5555 uint32_t magic;
56- uint32_t version ;
56+ uint32_t size ;
5757 tinykvm_x86regs regs;
5858 kvm_sregs sregs;
5959 tinykvm_x86fpuregs fpu;
@@ -83,31 +83,29 @@ struct ColdStartState {
8383 current = reinterpret_cast <char *>(current) + sizeof (T);
8484 // Bounds-check against end-of-structure
8585 if (reinterpret_cast <char *>(current) > reinterpret_cast <char *>(this ) + Size ()) {
86- throw std::runtime_error (" Out of bounds access on ColdStartState " );
86+ throw std::runtime_error (" Out of bounds access on SnapshotState " );
8787 }
8888 return ret;
8989 }
9090};
91- bool Machine::load_cold_start_state ()
91+ bool Machine::load_snapshot_state ()
9292{
93- if (!memory.has_loadable_cold_start_state ()) {
93+ if (!memory.has_loadable_snapshot_state ()) {
9494 return false ;
9595 }
96- if (!this ->memory .has_cold_start_area ) {
97- throw std::runtime_error (" No cold start state area allocated" );
96+ if (!this ->memory .has_snapshot_area ) {
97+ throw std::runtime_error (" No snapshot state area allocated" );
9898 }
9999 if (this ->is_forked ()) {
100- throw std::runtime_error (" Cannot load cold start state into a forked VM" );
100+ throw std::runtime_error (" Cannot load snapshot state into a forked VM" );
101101 }
102- void * map = this ->memory .get_cold_start_state_area ();
103- ColdStartState & state = *reinterpret_cast <ColdStartState *>(map);
104- if (state.magic != ColdStartState ::MAGIC) {
105- throw std::runtime_error (" No valid cold start state found" );
102+ void * map = this ->memory .get_snapshot_state_area ();
103+ SnapshotState & state = *reinterpret_cast <SnapshotState *>(map);
104+ if (state.magic != SnapshotState ::MAGIC) {
105+ throw std::runtime_error (" No valid snapshot state found" );
106106 }
107- if (state.version != 1 ) {
108- fprintf (stderr, " Warning: Cold start state version mismatch: %u != %u\n " ,
109- state.version , 1u );
110- return false ;
107+ if (state.size < sizeof (SnapshotState) || state.size > SnapshotState::Size ()) {
108+ throw std::runtime_error (" Invalid snapshot state size" );
111109 }
112110
113111 // Load the state into the VM
@@ -130,7 +128,7 @@ bool Machine::load_cold_start_state()
130128 this ->memory .main_memory_writes = state.main_memory_writes ;
131129 this ->memory .page_tables = state.m_page_tables ;
132130
133- void * current = reinterpret_cast <char *>(&state) + sizeof (ColdStartState );
131+ void * current = reinterpret_cast <char *>(&state) + sizeof (SnapshotState );
134132 // Load the thread states
135133 ColdStartThreads* threads = state.next <ColdStartThreads>(current);
136134 if (threads->count > 0 ) {
@@ -178,16 +176,16 @@ bool Machine::load_cold_start_state()
178176 }
179177 return true ;
180178}
181- void Machine::save_cold_start_state_now () const
179+ void Machine::save_snapshot_state_now () const
182180{
183181 if (this ->is_forked ()) {
184- throw std::runtime_error (" Cannot save cold start state of a forked VM" );
182+ throw std::runtime_error (" Cannot save snapshot state of a forked VM" );
185183 }
186- void * map = this ->memory .get_cold_start_state_area ();
187- ColdStartState & state = *reinterpret_cast <ColdStartState *>(map);
184+ void * map = this ->memory .get_snapshot_state_area ();
185+ SnapshotState & state = *reinterpret_cast <SnapshotState *>(map);
188186 try {
189- state.magic = ColdStartState ::MAGIC;
190- state.version = 1 ;
187+ state.magic = SnapshotState ::MAGIC;
188+ state.size = 0 ; // Invalid (for now)
191189 state.regs = this ->registers ();
192190 state.sregs = this ->get_special_registers ();
193191 state.fpu = this ->fpu_registers ();
@@ -206,7 +204,7 @@ void Machine::save_cold_start_state_now() const
206204 state.main_memory_writes = this ->memory .main_memory_writes ;
207205 state.m_page_tables = this ->memory .page_tables ;
208206
209- void * current = reinterpret_cast <char *>(&state) + sizeof (ColdStartState );
207+ void * current = reinterpret_cast <char *>(&state) + sizeof (SnapshotState );
210208 // Save the multi-threading state
211209 ColdStartThreads* threads = state.next <ColdStartThreads>(current);
212210 if (this ->has_threads ()) {
@@ -262,27 +260,50 @@ void Machine::save_cold_start_state_now() const
262260 csp->type = int (sp.type );
263261 }
264262
263+ // Finally, set the size
264+ state.size = static_cast <uint32_t >(
265+ reinterpret_cast <char *>(current) - reinterpret_cast <char *>(&state));
266+ if (state.size < sizeof (SnapshotState) || state.size > SnapshotState::Size ()) {
267+ throw std::runtime_error (" Snapshot state size was invalid" );
268+ }
269+
265270 } catch (const MachineException& me) {
266271 throw std::runtime_error (std::string (" Failed to get cold start state: " ) + me.what ());
267272 }
268273}
269274
270- void * vMemory::get_cold_start_state_area () const
275+ void * vMemory::get_snapshot_state_area () const
271276{
272- if (!this ->has_cold_start_area ) {
277+ if (!this ->has_snapshot_area ) {
273278 throw std::runtime_error (" No cold start state area allocated" );
274279 }
275280 // The cold start state area is after the end of the memory
276281 return (void *)(this ->ptr + this ->size );
277282}
278- bool vMemory::has_loadable_cold_start_state () const noexcept
283+ bool vMemory::has_loadable_snapshot_state () const noexcept
279284{
280- if (this ->has_cold_start_area ) {
281- void * area = this ->get_cold_start_state_area ();
285+ if (this ->has_snapshot_area ) {
286+ void * area = this ->get_snapshot_state_area ();
282287 uint32_t * magic = reinterpret_cast <uint32_t *>(area);
283- return *magic == ColdStartState ::MAGIC;
288+ return *magic == SnapshotState ::MAGIC;
284289 }
285290 return false ;
286291}
292+ void * Machine::get_snapshot_state_user_area () const
293+ {
294+ if (!this ->memory .has_snapshot_area ) {
295+ return nullptr ;
296+ }
297+ void * map = this ->memory .get_snapshot_state_area ();
298+ SnapshotState& state = *reinterpret_cast <SnapshotState*>(map);
299+ if (state.magic != SnapshotState::MAGIC) {
300+ return nullptr ;
301+ }
302+ if (state.size < sizeof (SnapshotState) || state.size > SnapshotState::Size ()) {
303+ return nullptr ;
304+ }
305+ // The user area is after the SnapshotState + size
306+ return reinterpret_cast <char *>(map) + sizeof (SnapshotState) + state.size ;
307+ }
287308
288309} // namespace tinykvm
0 commit comments