Skip to content

Commit 247d9be

Browse files
committed
fix(ipc): Add timeout to AwaitMessage to prevent indefinite blocking
- Add IPCAcceptTimeout (60s) and IPCReadTimeout (10s) to prevent orphaned processes when counterpart never connects
- Fix closure bug in executeHooksConcurrently using wrong loop variable
- Fix isRunning() using annotType instead of annotHypervisor
- Add tests for timeout and wrong message handling

Signed-off-by: Aman-Cool <aman017102007@gmail.com>
1 parent 66f01ee commit 247d9be

File tree

3 files changed

+90
-10
lines changed

3 files changed

+90
-10
lines changed

pkg/unikontainers/ipc.go

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ const (
4141
maxRetries = 50
4242
waitTime = 5 * time.Millisecond
4343
FromReexec = true
44+
// IPCAcceptTimeout is the maximum time to wait for a connection on the IPC socket.
45+
// This prevents processes from hanging indefinitely if the counterpart never connects
46+
// (e.g., due to containerd restart, node pressure, or orchestration failures).
47+
IPCAcceptTimeout = 60 * time.Second
48+
// IPCReadTimeout is the maximum time to wait for reading a message after connection.
49+
IPCReadTimeout = 10 * time.Second
4450
)
4551

4652
func getSockAddr(dir string, name string) string {
@@ -145,27 +151,48 @@ func createListener(socketAddress string, mustBeValid bool) (*net.UnixListener,
145151
return listener, nil
146152
}
147153

148-
// awaitMessage opens a new connection to socketAddress
149-
// and waits for a given message
154+
// AwaitMessage waits for a connection on the listener and reads an expected message.
155+
// It uses timeouts to prevent indefinite blocking if the counterpart process
156+
// never connects (e.g., due to orchestration failures, crashes, or restarts).
150157
func AwaitMessage(listener *net.UnixListener, expectedMessage IPCMessage) error {
158+
// Set accept deadline to prevent indefinite blocking.
159+
// This is critical for preventing orphaned processes when urunc start
160+
// never runs after urunc create, or when reexec fails silently.
161+
if err := listener.SetDeadline(time.Now().Add(IPCAcceptTimeout)); err != nil {
162+
return fmt.Errorf("failed to set listener deadline: %w", err)
163+
}
164+
151165
conn, err := listener.AcceptUnix()
152166
if err != nil {
153-
return err
167+
var netErr net.Error
168+
if errors.As(err, &netErr) && netErr.Timeout() {
169+
return fmt.Errorf("timeout waiting for IPC connection (waited %v): counterpart process may have failed or not started", IPCAcceptTimeout)
170+
}
171+
return fmt.Errorf("failed to accept connection: %w", err)
154172
}
155173
defer func() {
156-
err = conn.Close()
157-
if err != nil {
158-
logrus.WithError(err).Error("failed to close connection")
174+
if closeErr := conn.Close(); closeErr != nil {
175+
logrus.WithError(closeErr).Error("failed to close connection")
159176
}
160177
}()
178+
179+
// Set read deadline to prevent hanging on slow or stuck writers
180+
if err := conn.SetReadDeadline(time.Now().Add(IPCReadTimeout)); err != nil {
181+
return fmt.Errorf("failed to set read deadline: %w", err)
182+
}
183+
161184
buf := make([]byte, len(expectedMessage))
162185
n, err := conn.Read(buf)
163186
if err != nil {
187+
var netErr net.Error
188+
if errors.As(err, &netErr) && netErr.Timeout() {
189+
return fmt.Errorf("timeout reading IPC message (waited %v): counterpart process may be stuck", IPCReadTimeout)
190+
}
164191
return fmt.Errorf("failed to read from socket: %w", err)
165192
}
166193
msg := string(buf[0:n])
167194
if msg != string(expectedMessage) {
168-
return fmt.Errorf("received unexpected message: %s", msg)
195+
return fmt.Errorf("received unexpected message: %s (expected: %s)", msg, expectedMessage)
169196
}
170197
return nil
171198
}

pkg/unikontainers/ipc_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,56 @@ func TestAwaitMessage(t *testing.T) {
176176
err = AwaitMessage(listener, expectedMessage)
177177
assert.NoError(t, err, "Expected no error in awaiting message")
178178
}
179+
180+
func TestAwaitMessageTimeout(t *testing.T) {
181+
socketAddress := "/tmp/test_await_message_timeout.sock"
182+
expectedMessage := ReexecStarted
183+
184+
listener, err := createListener(socketAddress, true)
185+
if err != nil {
186+
t.Fatalf("Failed to create listener: %v", err)
187+
}
188+
defer listener.Close()
189+
190+
// Don't send any message - this should trigger a timeout
191+
// Note: For testing, we need shorter timeouts than production.
192+
// The actual timeout check is that it returns an error containing "timeout"
193+
// rather than blocking forever.
194+
195+
// Set a shorter deadline for testing purposes
196+
listener.SetDeadline(time.Now().Add(100 * time.Millisecond))
197+
198+
err = AwaitMessage(listener, expectedMessage)
199+
assert.Error(t, err, "Expected timeout error when no connection arrives")
200+
assert.Contains(t, err.Error(), "timeout", "Expected error message to mention timeout")
201+
}
202+
203+
func TestAwaitMessageWrongMessage(t *testing.T) {
204+
socketAddress := "/tmp/test_await_wrong_message.sock"
205+
expectedMessage := ReexecStarted
206+
wrongMessage := StartExecve
207+
208+
listener, err := createListener(socketAddress, true)
209+
if err != nil {
210+
t.Fatalf("Failed to create listener: %v", err)
211+
}
212+
defer listener.Close()
213+
214+
go func() {
215+
conn, err := net.Dial("unix", socketAddress)
216+
if err != nil {
217+
t.Errorf("Failed to dial connection: %v", err)
218+
}
219+
defer conn.Close()
220+
221+
// Send wrong message
222+
_, err = conn.Write([]byte(wrongMessage))
223+
if err != nil {
224+
t.Errorf("Failed to send message: %v", err)
225+
}
226+
}()
227+
228+
err = AwaitMessage(listener, expectedMessage)
229+
assert.Error(t, err, "Expected error for unexpected message")
230+
assert.Contains(t, err.Error(), "unexpected message", "Expected error to mention unexpected message")
231+
}

pkg/unikontainers/unikontainers.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -753,8 +753,8 @@ func (u *Unikontainer) executeHooksConcurrently(name string, hooks []specs.Hook,
753753
uniklog.WithFields(logrus.Fields{
754754
"id": u.State.ID,
755755
"name": name,
756-
"path": hooks[i].Path,
757-
"args": hooks[i].Args,
756+
"path": h.Path,
757+
"args": h.Args,
758758
"error": err,
759759
}).Error("Executing hook failed")
760760
errChan <- err
@@ -1121,7 +1121,7 @@ func (u *Unikontainer) SendMessage(message IPCMessage) error {
11211121

11221122
// isRunning returns true if the PID is alive or hedge.ListVMs returns our containerID
11231123
func (u *Unikontainer) isRunning() bool {
1124-
vmmType := hypervisors.VmmType(u.State.Annotations[annotType])
1124+
vmmType := hypervisors.VmmType(u.State.Annotations[annotHypervisor])
11251125
if vmmType != hypervisors.HedgeVmm {
11261126
return syscall.Kill(u.State.Pid, syscall.Signal(0)) == nil
11271127
}

0 commit comments

Comments
 (0)