Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 175 additions & 25 deletions crates/libcontainer/src/process/channel.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
use std::collections::HashMap;
use std::os::unix::prelude::{AsRawFd, RawFd};
use std::os::unix::prelude::{AsRawFd, FromRawFd, OwnedFd, RawFd};

use nix::unistd::Pid;

use crate::channel::{Receiver, Sender, channel};
use crate::network::cidr::CidrAddress;
use crate::process::message::Message;
use crate::process::message::{Message, MountMsg};

#[derive(Debug, thiserror::Error)]
pub enum ChannelError {
#[error("received unexpected message: {received:?}, expected: {expected:?}")]
#[error("received unexpected message: {received:?}, expected: {expected}")]
UnexpectedMessage {
expected: Message,
received: Message,
expected: &'static str,
received: Box<Message>,
},
#[error("failed to receive. {msg:?}. {source:?}")]
ReceiveError {
Expand All @@ -28,6 +28,10 @@ pub enum ChannelError {
ExecError(String),
#[error("intermediate process error {0}")]
OtherError(String),
#[error("missing fd from mount request")]
MissingMountFds,
#[error("mount request failed: {0}")]
MountFdError(String),
}

// Channel Design
Expand Down Expand Up @@ -67,6 +71,12 @@ impl MainSender {
Ok(())
}

pub fn request_mount_fd(&mut self, msg: MountMsg) -> Result<(), ChannelError> {
self.sender.send(Message::MountFdPlease(msg))?;

Ok(())
}

pub fn network_setup_ready(&mut self) -> Result<(), ChannelError> {
tracing::debug!("notify network setup ready");
self.sender.send(Message::SetupNetworkDeviceReady)?;
Expand Down Expand Up @@ -131,8 +141,8 @@ impl MainReceiver {
Message::ExecFailed(err) => Err(ChannelError::ExecError(err)),
Message::OtherError(err) => Err(ChannelError::OtherError(err)),
msg => Err(ChannelError::UnexpectedMessage {
expected: Message::IntermediateReady(0),
received: msg,
expected: "IntermediateReady",
received: Box::new(msg),
}),
}
}
Expand All @@ -148,12 +158,39 @@ impl MainReceiver {
match msg {
Message::WriteMapping => Ok(()),
msg => Err(ChannelError::UnexpectedMessage {
expected: Message::WriteMapping,
received: msg,
expected: "WriteMapping",
received: Box::new(msg),
}),
}
}

pub fn wait_for_mount_fd_request(&mut self) -> Result<MountMsg, ChannelError> {
let msg = self
.receiver
.recv()
.map_err(|err| ChannelError::ReceiveError {
msg: "waiting for mount fd request".to_string(),
source: err,
})?;

match msg {
Message::MountFdPlease(req) => Ok(req),
msg => Err(ChannelError::UnexpectedMessage {
expected: "MountFdPlease",
received: Box::new(msg),
}),
}
}

pub fn recv_message_with_fds(&mut self) -> Result<(Message, Option<[RawFd; 1]>), ChannelError> {
self.receiver
.recv_with_fds::<[RawFd; 1]>()
.map_err(|err| ChannelError::ReceiveError {
msg: "waiting for message".to_string(),
source: err,
})
}

pub fn wait_for_seccomp_request(&mut self) -> Result<i32, ChannelError> {
let (msg, fds) = self.receiver.recv_with_fds::<[RawFd; 1]>().map_err(|err| {
ChannelError::ReceiveError {
Expand All @@ -177,8 +214,8 @@ impl MainReceiver {
Ok(fd)
}
msg => Err(ChannelError::UnexpectedMessage {
expected: Message::SeccompNotify,
received: msg,
expected: "SeccompNotify",
received: Box::new(msg),
}),
}
}
Expand All @@ -194,8 +231,8 @@ impl MainReceiver {
match msg {
Message::SetupNetworkDeviceReady => Ok(()),
msg => Err(ChannelError::UnexpectedMessage {
expected: Message::SetupNetworkDeviceReady,
received: msg,
expected: "SetupNetworkDeviceReady",
received: Box::new(msg),
}),
}
}
Expand All @@ -217,8 +254,8 @@ impl MainReceiver {
"error in executing process : {err}"
))),
msg => Err(ChannelError::UnexpectedMessage {
expected: Message::InitReady,
received: msg,
expected: "InitReady",
received: Box::new(msg),
}),
}
}
Expand All @@ -234,8 +271,8 @@ impl MainReceiver {
match msg {
Message::HookRequest => Ok(()),
msg => Err(ChannelError::UnexpectedMessage {
expected: Message::HookRequest,
received: msg,
expected: "HookRequest",
received: Box::new(msg),
}),
}
}
Expand Down Expand Up @@ -292,8 +329,8 @@ impl IntermediateReceiver {
match msg {
Message::MappingWritten => Ok(()),
msg => Err(ChannelError::UnexpectedMessage {
expected: Message::MappingWritten,
received: msg,
expected: "MappingWritten",
received: Box::new(msg),
}),
}
}
Expand Down Expand Up @@ -340,6 +377,17 @@ impl InitSender {

Ok(())
}

pub fn send_mount_fd_reply(&mut self, fd: RawFd) -> Result<(), ChannelError> {
self.sender.send_fds(Message::MountFdReply, &[fd])?;

Ok(())
}

pub fn send_mount_fd_error(&mut self, err: String) -> Result<(), ChannelError> {
self.sender.send(Message::MountFdError(err))?;
Ok(())
}
}

pub struct InitReceiver {
Expand All @@ -359,8 +407,8 @@ impl InitReceiver {
match msg {
Message::SeccompNotifyDone => Ok(()),
msg => Err(ChannelError::UnexpectedMessage {
expected: Message::SeccompNotifyDone,
received: msg,
expected: "SeccompNotifyDone",
received: Box::new(msg),
}),
}
}
Expand All @@ -378,8 +426,8 @@ impl InitReceiver {
match msg {
Message::MoveNetworkDevice(addr) => Ok(addr),
msg => Err(ChannelError::UnexpectedMessage {
expected: Message::WriteMapping,
received: msg,
expected: "MoveNetworkDevice",
received: Box::new(msg),
}),
}
}
Expand All @@ -395,8 +443,8 @@ impl InitReceiver {
match msg {
Message::HookDone => Ok(()),
msg => Err(ChannelError::UnexpectedMessage {
expected: Message::HookDone,
received: msg,
expected: "HookDone",
received: Box::new(msg),
}),
}
}
Expand All @@ -406,10 +454,37 @@ impl InitReceiver {

Ok(())
}

pub fn wait_for_mount_fd_reply(&mut self) -> Result<OwnedFd, ChannelError> {
let (msg, fds) = self.receiver.recv_with_fds::<[RawFd; 1]>().map_err(|err| {
ChannelError::ReceiveError {
msg: "waiting for mount fd reply".to_string(),
source: err,
}
})?;

match msg {
Message::MountFdReply => {
let fd = match fds {
Some([fd]) => fd,
_ => return Err(ChannelError::MissingMountFds),
};
Ok(unsafe { OwnedFd::from_raw_fd(fd) })
}
Message::MountFdError(err) => Err(ChannelError::MountFdError(err)),
msg => Err(ChannelError::UnexpectedMessage {
expected: "MountFdReply",
received: Box::new(msg),
}),
}
}
}

#[cfg(test)]
mod tests {
use std::io::{Read, Seek, SeekFrom, Write};
use std::os::fd::AsRawFd;

use anyhow::{Context, Result};
use nix::sys::wait;
use nix::unistd;
Expand Down Expand Up @@ -490,6 +565,81 @@ mod tests {
Ok(())
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a test for wait_for_mount_fd_request?

#[test]
#[serial]
fn test_channel_mount_fd_error() -> Result<()> {
let (sender, receiver) = &mut init_channel()?;
sender.send_mount_fd_error("boom".to_string())?;
let err = receiver.wait_for_mount_fd_reply().unwrap_err();
assert!(matches!(err, ChannelError::MountFdError(msg) if msg == "boom"));
sender.close()?;
receiver.close()?;
Ok(())
}

#[test]
#[serial]
fn test_channel_mount_fd_reply_success() -> Result<()> {
let (sender, receiver) = &mut init_channel()?;
let mut file = tempfile::tempfile()?;
file.write_all(b"ok")?;

sender.send_mount_fd_reply(file.as_raw_fd())?;
let fd = receiver.wait_for_mount_fd_reply()?;
let mut received = std::fs::File::from(fd);
received.seek(SeekFrom::Start(0))?;
let mut buf = String::new();
received.read_to_string(&mut buf)?;
assert_eq!(buf, "ok");

sender.close()?;
receiver.close()?;
Ok(())
}

#[test]
#[serial]
fn test_channel_mount_fd_reply_missing_fds() -> Result<()> {
let (mut sender, receiver) = channel::<Message>()?;
let mut receiver = InitReceiver { receiver };

sender.send(Message::MountFdReply)?;
let err = receiver.wait_for_mount_fd_reply().unwrap_err();
assert!(matches!(err, ChannelError::MissingMountFds));

sender.close()?;
receiver.close()?;
Ok(())
}

#[test]
#[serial]
fn test_channel_mount_fd_request() -> Result<()> {
let (sender, receiver) = &mut main_channel()?;
let request = MountMsg {
source: "/proc/self/ns/user".to_string(),
idmap: Some(crate::process::message::MountIdMap {
uid_mappings: vec![],
gid_mappings: vec![],
recursive: true,
}),
};

sender.request_mount_fd(request.clone())?;
let received = receiver.wait_for_mount_fd_request()?;

assert_eq!(received.source, request.source);
let received_idmap = received.idmap.context("missing idmap in mount request")?;
let request_idmap = request.idmap.context("missing idmap in mount request")?;
assert_eq!(received_idmap.recursive, request_idmap.recursive);
assert_eq!(received_idmap.uid_mappings, request_idmap.uid_mappings);
assert_eq!(received_idmap.gid_mappings, request_idmap.gid_mappings);

sender.close()?;
receiver.close()?;
Ok(())
}

#[test]
#[serial]
fn test_channel_init_ready() -> Result<()> {
Expand Down
20 changes: 20 additions & 0 deletions crates/libcontainer/src/process/message.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use core::fmt;
use std::collections::HashMap;

use oci_spec::runtime::LinuxIdMapping;
use serde::{Deserialize, Serialize};

use crate::network::cidr::CidrAddress;
Expand All @@ -16,8 +17,11 @@ pub enum Message {
SeccompNotifyDone,
SetupNetworkDeviceReady,
MoveNetworkDevice(HashMap<String, Vec<CidrAddress>>),
MountFdPlease(MountMsg),
MountFdReply,
ExecFailed(String),
OtherError(String),
MountFdError(String),
HookRequest,
HookDone,
}
Expand All @@ -35,8 +39,24 @@ impl fmt::Display for Message {
Message::SeccompNotifyDone => write!(f, "SeccompNotifyDone"),
Message::HookRequest => write!(f, "HookRequest"),
Message::HookDone => write!(f, "HookDone"),
Message::MountFdPlease(_) => write!(f, "MountFdPlease"),
Message::MountFdReply => write!(f, "MountFdReply"),
Message::MountFdError(err) => write!(f, "MountFdError({})", err),
Message::ExecFailed(s) => write!(f, "ExecFailed({})", s),
Message::OtherError(s) => write!(f, "OtherError({})", s),
}
}
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct MountMsg {
pub source: String,
pub idmap: Option<MountIdMap>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct MountIdMap {
pub uid_mappings: Vec<LinuxIdMapping>,
pub gid_mappings: Vec<LinuxIdMapping>,
pub recursive: bool,
}
Loading