Skip to content

std: make address resolution weirdness local to SGX #145327

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions library/std/src/io/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ impl Error {

pub(crate) const ZERO_TIMEOUT: Self =
const_error!(ErrorKind::InvalidInput, "cannot set a 0 duration timeout");

pub(crate) const NO_ADDRESSES: Self =
const_error!(ErrorKind::InvalidInput, "could not resolve to any addresses");
}

#[stable(feature = "rust1", since = "1.0.0")]
Expand Down
21 changes: 0 additions & 21 deletions library/std/src/net/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ pub use self::tcp::IntoIncoming;
pub use self::tcp::{Incoming, TcpListener, TcpStream};
#[stable(feature = "rust1", since = "1.0.0")]
pub use self::udp::UdpSocket;
use crate::io::{self, ErrorKind};

mod ip_addr;
mod socket_addr;
Expand Down Expand Up @@ -67,23 +66,3 @@ pub enum Shutdown {
#[stable(feature = "rust1", since = "1.0.0")]
Both,
}

fn each_addr<A: ToSocketAddrs, F, T>(addr: A, mut f: F) -> io::Result<T>
where
F: FnMut(io::Result<&SocketAddr>) -> io::Result<T>,
{
let addrs = match addr.to_socket_addrs() {
Ok(addrs) => addrs,
Err(e) => return f(Err(e)),
};
let mut last_err = None;
for addr in addrs {
match f(Ok(&addr)) {
Ok(l) => return Ok(l),
Err(e) => last_err = Some(e),
}
}
Err(last_err.unwrap_or_else(|| {
io::const_error!(ErrorKind::InvalidInput, "could not resolve to any addresses")
}))
}
4 changes: 2 additions & 2 deletions library/std/src/net/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ impl TcpStream {
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<TcpStream> {
super::each_addr(addr, net_imp::TcpStream::connect).map(TcpStream)
net_imp::TcpStream::connect(addr).map(TcpStream)
}

/// Opens a TCP connection to a remote host with a timeout.
Expand Down Expand Up @@ -782,7 +782,7 @@ impl TcpListener {
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<TcpListener> {
super::each_addr(addr, net_imp::TcpListener::bind).map(TcpListener)
net_imp::TcpListener::bind(addr).map(TcpListener)
}

/// Returns the local socket address of this listener.
Expand Down
2 changes: 1 addition & 1 deletion library/std/src/net/tcp/tests.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::io::prelude::*;
use crate::io::{BorrowedBuf, IoSlice, IoSliceMut};
use crate::io::{BorrowedBuf, ErrorKind, IoSlice, IoSliceMut};
use crate::mem::MaybeUninit;
use crate::net::test::{next_test_ip4, next_test_ip6};
use crate::net::*;
Expand Down
4 changes: 2 additions & 2 deletions library/std/src/net/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ impl UdpSocket {
/// [`Ipv4Addr::UNSPECIFIED`] or [`Ipv6Addr::UNSPECIFIED`].
#[stable(feature = "rust1", since = "1.0.0")]
pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<UdpSocket> {
super::each_addr(addr, net_imp::UdpSocket::bind).map(UdpSocket)
net_imp::UdpSocket::bind(addr).map(UdpSocket)
}

/// Receives a single datagram message on the socket. On success, returns the number
Expand Down Expand Up @@ -677,7 +677,7 @@ impl UdpSocket {
/// on the platform.
#[stable(feature = "net2_mutators", since = "1.9.0")]
pub fn connect<A: ToSocketAddrs>(&self, addr: A) -> io::Result<()> {
super::each_addr(addr, |addr| self.0.connect(addr))
self.0.connect(addr)
}

/// Sends data on the socket to the remote address to which it is connected.
Expand Down
1 change: 1 addition & 0 deletions library/std/src/net/udp/tests.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::io::ErrorKind;
use crate::net::test::{compare_ignore_zoneid, next_test_ip4, next_test_ip6};
use crate::net::*;
use crate::sync::mpsc::channel;
Expand Down
52 changes: 52 additions & 0 deletions library/std/src/sys/net/connection/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
cfg_if::cfg_if! {
if #[cfg(any(
all(target_family = "unix", not(target_os = "l4re")),
target_os = "windows",
target_os = "hermit",
all(target_os = "wasi", target_env = "p2"),
target_os = "solid_asp3",
))] {
mod socket;
pub use socket::*;
} else if #[cfg(all(target_vendor = "fortanix", target_env = "sgx"))] {
mod sgx;
pub use sgx::*;
} else if #[cfg(all(target_os = "wasi", target_env = "p1"))] {
mod wasip1;
pub use wasip1::*;
} else if #[cfg(target_os = "xous")] {
mod xous;
pub use xous::*;
} else if #[cfg(target_os = "uefi")] {
mod uefi;
pub use uefi::*;
} else {
mod unsupported;
pub use unsupported::*;
}
}

#[cfg_attr(
// Make sure that this is used on some platforms at least.
not(any(target_os = "linux", target_os = "windows")),
allow(dead_code)
)]
fn each_addr<A: crate::net::ToSocketAddrs, F, T>(addr: A, mut f: F) -> crate::io::Result<T>
where
F: FnMut(&crate::net::SocketAddr) -> crate::io::Result<T>,
{
use crate::io::Error;

let mut last_err = None;
for addr in addr.to_socket_addrs()? {
match f(&addr) {
Ok(l) => return Ok(l),
Err(e) => last_err = Some(e),
}
}

match last_err {
Some(err) => Err(err),
None => Err(Error::NO_ADDRESSES),
}
}
80 changes: 58 additions & 22 deletions library/std/src/sys/net/connection/sgx.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use crate::error;
use crate::fmt::{self, Write};
use crate::io::{self, BorrowedCursor, IoSlice, IoSliceMut};
use crate::net::{Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, ToSocketAddrs};
use crate::sync::Arc;
use crate::sys::abi::usercalls;
use crate::sys::fd::FileDesc;
use crate::sys::{AsInner, FromInner, IntoInner, TryIntoInner, sgx_ineffective, unsupported};
use crate::time::Duration;
use crate::{error, fmt};

const DEFAULT_FAKE_TTL: u32 = 64;

Expand Down Expand Up @@ -63,18 +64,51 @@ impl fmt::Debug for TcpStream {
}
}

fn io_err_to_addr(result: io::Result<&SocketAddr>) -> io::Result<String> {
match result {
Ok(saddr) => Ok(saddr.to_string()),
// need to downcast twice because io::Error::into_inner doesn't return the original
// value if the conversion fails
Err(e) => {
if e.get_ref().and_then(|e| e.downcast_ref::<NonIpSockAddr>()).is_some() {
Ok(e.into_inner().unwrap().downcast::<NonIpSockAddr>().unwrap().host)
} else {
Err(e)
/// Converts each address in `addr` into a hostname.
///
/// SGX doesn't support DNS resolution but rather accepts hostnames in
/// the same place as socket addresses. So, to make e.g.
/// ```rust
/// TcpStream::connect("example.com:80")`
/// ```
/// work, the DNS lookup returns a special error (`NonIpSockAddr`) instead,
/// which contains the hostname being looked up. When `.to_socket_addrs()`
/// fails, we inspect the error and try recover the hostname from it. If that
/// succeeds, we thus continue with the hostname.
///
/// This is a terrible hack and leads to buggy code. For instance, when users
/// use the result of `.to_socket_addrs()` in their own `ToSocketAddrs`
/// implementation to select from a list of possible URLs, the only URL used
/// will be that of the last item tried.
// FIXME: This is a terrible, terrible hack.
fn each_addr<A: ToSocketAddrs, F, T>(addr: A, mut f: F) -> io::Result<T>
where
F: FnMut(&str) -> io::Result<T>,
{
match addr.to_socket_addrs() {
Ok(addrs) => {
let mut last_err = None;
let mut encoded = String::new();
for addr in addrs {
write!(encoded, "{}", &addr).unwrap();
match f(&encoded) {
Ok(val) => return Ok(val),
Err(err) => {
last_err = Some(err);
encoded.clear();
}
}
}

match last_err {
Some(err) => Err(err),
None => Err(io::Error::NO_ADDRESSES),
}
}
Err(err) => match err.get_ref().and_then(|e| e.downcast_ref::<NonIpSockAddr>()) {
Some(NonIpSockAddr { host }) => f(host),
None => Err(err),
},
}
}

Expand All @@ -86,17 +120,18 @@ fn addr_to_sockaddr(addr: Option<&str>) -> io::Result<SocketAddr> {
}

impl TcpStream {
pub fn connect(addr: io::Result<&SocketAddr>) -> io::Result<TcpStream> {
let addr = io_err_to_addr(addr)?;
let (fd, local_addr, peer_addr) = usercalls::connect_stream(&addr)?;
Ok(TcpStream { inner: Socket::new(fd, local_addr), peer_addr: Some(peer_addr) })
pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<TcpStream> {
each_addr(addr, |addr| {
let (fd, local_addr, peer_addr) = usercalls::connect_stream(addr)?;
Ok(TcpStream { inner: Socket::new(fd, local_addr), peer_addr: Some(peer_addr) })
})
}

pub fn connect_timeout(addr: &SocketAddr, dur: Duration) -> io::Result<TcpStream> {
if dur == Duration::default() {
return Err(io::Error::ZERO_TIMEOUT);
}
Self::connect(Ok(addr)) // FIXME: ignoring timeout
Self::connect(addr) // FIXME: ignoring timeout
}

pub fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
Expand Down Expand Up @@ -247,10 +282,11 @@ impl fmt::Debug for TcpListener {
}

impl TcpListener {
pub fn bind(addr: io::Result<&SocketAddr>) -> io::Result<TcpListener> {
let addr = io_err_to_addr(addr)?;
let (fd, local_addr) = usercalls::bind_stream(&addr)?;
Ok(TcpListener { inner: Socket::new(fd, local_addr) })
pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<TcpListener> {
each_addr(addr, |addr| {
let (fd, local_addr) = usercalls::bind_stream(addr)?;
Ok(TcpListener { inner: Socket::new(fd, local_addr) })
})
}

pub fn socket_addr(&self) -> io::Result<SocketAddr> {
Expand Down Expand Up @@ -316,7 +352,7 @@ impl FromInner<Socket> for TcpListener {
pub struct UdpSocket(!);

impl UdpSocket {
pub fn bind(_: io::Result<&SocketAddr>) -> io::Result<UdpSocket> {
pub fn bind<A: ToSocketAddrs>(_: A) -> io::Result<UdpSocket> {
unsupported()
}

Expand Down Expand Up @@ -436,7 +472,7 @@ impl UdpSocket {
self.0
}

pub fn connect(&self, _: io::Result<&SocketAddr>) -> io::Result<()> {
pub fn connect<A: ToSocketAddrs>(&self, _: A) -> io::Result<()> {
self.0
}
}
Expand Down
Loading
Loading