@@ -29,6 +29,7 @@ import "C"
2929
3030import (
3131 "encoding/binary"
32+ "encoding/json"
3233 "errors"
3334 "fmt"
3435 "net"
@@ -39,7 +40,6 @@ import (
3940 "syscall"
4041 "time"
4142
42- "github.com/ftrvxmtrx/fd"
4343 log "github.com/sirupsen/logrus"
4444 "golang.org/x/sys/unix"
4545
@@ -133,44 +133,49 @@ func (s *SnapshotState) setupStateOnActivate() {
133133 }
134134}
135135
136- func (s * SnapshotState ) getUFFD (sendfdConn * net.UnixConn ) error {
137- // var d net.Dialer
138- // ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
139- // defer cancel()
140-
141- // for {
142- // c, err := d.DialContext(ctx, "unix", s.InstanceSockAddr)
143- // if err != nil {
144- // if ctx.Err() != nil {
145- // log.Error("Failed to dial within the context timeout")
146- // return err
147- // }
148- // time.Sleep(1 * time.Millisecond)
149- // continue
150- // }
151-
152- // defer c.Close()
153-
154- // sendfdConn := c.(*net.UnixConn)
136+ type GuestRegionUffdMapping struct {
137+ BaseHostVirtAddr uint64 `json:"base_host_virt_addr"`
138+ Size uint64 `json:"size"`
139+ Offset uint64 `json:"offset"`
140+ PageSizeKiB uint64 `json:"page_size_kib"`
141+ }
155142
156- // fs, err := fd.Get(sendfdConn, 1, []string{"a file"})
157- // if err != nil {
158- // log.Error("Failed to receive the uffd")
159- // return err
160- // }
143+ func (s * SnapshotState ) getUFFD (sendfdConn * net.UnixConn ) error {
144+ buff := make ([]byte , 256 ) // set a maximum buffer size
145+ oobBuff := make ([]byte , unix .CmsgSpace (4 ))
161146
162- // s.userFaultFD = fs[0]
147+ n , oobn , _ , _ , err := sendfdConn .ReadMsgUnix (buff , oobBuff )
148+ if err != nil {
149+ return fmt .Errorf ("error reading message: %w" , err )
150+ }
151+ buff = buff [:n ]
163152
164- // return nil
165- // }
153+ var fd int
154+ if oobn > 0 {
155+ scms , err := unix .ParseSocketControlMessage (oobBuff [:oobn ])
156+ if err != nil {
157+ return fmt .Errorf ("error parsing socket control message: %w" , err )
158+ }
159+ for _ , scm := range scms {
160+ fds , err := unix .ParseUnixRights (& scm )
161+ if err != nil {
162+ return fmt .Errorf ("error parsing unix rights: %w" , err )
163+ }
164+ if len (fds ) > 0 {
165+ fd = fds [0 ] // Assuming only one fd is sent.
166+ break
167+ }
168+ }
169+ }
170+ userfaultFD := os .NewFile (uintptr (fd ), "userfaultfd" )
166171
167- fs , err := fd .Get (sendfdConn , 1 , []string {"a file" })
168- if err != nil {
169- log .Error ("Failed to receive the uffd" )
170- return err
172+ var mapping []GuestRegionUffdMapping
173+ if err := json .Unmarshal (buff , & mapping ); err != nil {
174+ return fmt .Errorf ("error unmarshaling data: %w" , err )
171175 }
172176
173- s .userFaultFD = fs [0 ]
177+ s .startAddress = mapping [0 ].BaseHostVirtAddr
178+ s .userFaultFD = userfaultFD
174179 return nil
175180}
176181
@@ -401,7 +406,6 @@ func (s *SnapshotState) servePageFault(fd int, address uint64) error {
401406
402407 s .firstPageFaultOnce .Do (
403408 func () {
404- s .startAddress = address
405409 log .Debugf ("TEST: first page fault address %d" , address )
406410
407411 if s .isRecordReady && ! s .IsLazyMode {
0 commit comments