Skip to content

Commit 23c29c0

Browse files
Robert Ruschfacebook-github-bot
authored andcommitted
Fix up failing tests by adding Tcp::Localhost mode (meta-pytorch#1479)
Summary: This diff adds support for tls localhost to fix T232884876. I go into detail in https://fb.workplace.com/groups/1434680487636162/posts/1441470780290466/, but basically, the hostname can't be bound to for certain environments. This generally takes the approach discussed with mariusae, though I ended up doing localhost instead of unspecified. This is because the tests were already set up with localhost; it was just being ignored. This is a bit of a short term fix, as I hope to follow up by allowing the hostname to get sent to workers while using unspecified locally. DNS is "more correctly" resolved on the worker machines, due to proxying and load balancing questions, but that's a bigger lift. Reviewed By: mariusae Differential Revision: D84094699
1 parent 052e78f commit 23c29c0

File tree

14 files changed

+121
-50
lines changed

14 files changed

+121
-50
lines changed

hyper/src/commands/serve.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use std::time::Duration;
1010

1111
use hyperactor::channel::ChannelAddr;
1212
use hyperactor::channel::ChannelTransport;
13+
use hyperactor::channel::TcpMode;
1314
use hyperactor_multiprocess::system::System;
1415

1516
// The commands in the demo spawn temporary actors the join a system.
@@ -27,7 +28,9 @@ pub struct ServeCommand {
2728

2829
impl ServeCommand {
2930
pub async fn run(self) -> anyhow::Result<()> {
30-
let addr = self.addr.unwrap_or(ChannelAddr::any(ChannelTransport::Tcp));
31+
let addr = self
32+
.addr
33+
.unwrap_or(ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)));
3134
let handle = System::serve(addr, LONG_DURATION, LONG_DURATION).await?;
3235
eprintln!("serve: {}", handle.local_addr());
3336
handle.await;

hyperactor/benches/main.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use hyperactor::channel;
2222
use hyperactor::channel::ChannelAddr;
2323
use hyperactor::channel::ChannelTransport;
2424
use hyperactor::channel::Rx;
25+
use hyperactor::channel::TcpMode;
2526
use hyperactor::channel::Tx;
2627
use hyperactor::channel::dial;
2728
use hyperactor::channel::serve;
@@ -66,7 +67,7 @@ impl Message {
6667
fn bench_message_sizes(c: &mut Criterion) {
6768
let transports = vec![
6869
("local", ChannelTransport::Local),
69-
("tcp", ChannelTransport::Tcp),
70+
("tcp", ChannelTransport::Tcp(TcpMode::Hostname)),
7071
("unix", ChannelTransport::Unix),
7172
];
7273

@@ -108,7 +109,7 @@ fn bench_message_rates(c: &mut Criterion) {
108109

109110
let transports = vec![
110111
("local", ChannelTransport::Local),
111-
("tcp", ChannelTransport::Tcp),
112+
("tcp", ChannelTransport::Tcp(TcpMode::Hostname)),
112113
("unix", ChannelTransport::Unix),
113114
//TODO Add TLS once it is able to run in Sandcastle
114115
];

hyperactor/src/channel.rs

Lines changed: 62 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
use core::net::SocketAddr;
1515
use std::fmt;
1616
use std::net::IpAddr;
17+
use std::net::Ipv4Addr;
18+
use std::net::Ipv6Addr;
1719
#[cfg(target_os = "linux")]
1820
use std::os::linux::net::SocketAddrExt;
1921
use std::str::FromStr;
@@ -236,6 +238,26 @@ impl<M: RemoteMessage> Rx<M> for MpscRx<M> {
236238
}
237239
}
238240

241+
/// The hostname to use for TLS connections.
242+
#[derive(
243+
Clone,
244+
Debug,
245+
PartialEq,
246+
Eq,
247+
Hash,
248+
Serialize,
249+
Deserialize,
250+
strum::EnumIter,
251+
strum::Display,
252+
strum::EnumString
253+
)]
254+
pub enum TcpMode {
255+
/// Use localhost/loopback for the connection.
256+
Localhost,
257+
/// Use host domain name for the connection.
258+
Hostname,
259+
}
260+
239261
/// The hostname to use for TLS connections.
240262
#[derive(
241263
Clone,
@@ -315,7 +337,7 @@ impl fmt::Display for MetaTlsAddr {
315337
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Named)]
316338
pub enum ChannelTransport {
317339
/// Transport over a TCP connection.
318-
Tcp,
340+
Tcp(TcpMode),
319341

320342
/// Transport over a TCP connection with TLS support within Meta
321343
MetaTls(TlsMode),
@@ -333,7 +355,7 @@ pub enum ChannelTransport {
333355
impl fmt::Display for ChannelTransport {
334356
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
335357
match self {
336-
Self::Tcp => write!(f, "tcp"),
358+
Self::Tcp(mode) => write!(f, "tcp({:?})", mode),
337359
Self::MetaTls(mode) => write!(f, "metatls({:?})", mode),
338360
Self::Local => write!(f, "local"),
339361
Self::Sim(transport) => write!(f, "sim({})", transport),
@@ -358,7 +380,13 @@ impl FromStr for ChannelTransport {
358380
}
359381

360382
match s {
361-
"tcp" => Ok(ChannelTransport::Tcp),
383+
// Default to TcpMode::Hostname, if the mode isn't set
384+
"tcp" => Ok(ChannelTransport::Tcp(TcpMode::Hostname)),
385+
s if s.starts_with("tcp(") => {
386+
let inner = &s["tcp(".len()..s.len() - 1];
387+
let mode = inner.parse()?;
388+
Ok(ChannelTransport::Tcp(mode))
389+
}
362390
"local" => Ok(ChannelTransport::Local),
363391
"unix" => Ok(ChannelTransport::Unix),
364392
s if s.starts_with("metatls(") && s.ends_with(")") => {
@@ -373,9 +401,10 @@ impl FromStr for ChannelTransport {
373401

374402
impl ChannelTransport {
375403
/// All known channel transports.
376-
pub fn all() -> [ChannelTransport; 3] {
404+
pub fn all() -> [ChannelTransport; 4] {
377405
[
378-
ChannelTransport::Tcp,
406+
ChannelTransport::Tcp(TcpMode::Localhost),
407+
ChannelTransport::Tcp(TcpMode::Hostname),
379408
ChannelTransport::Local,
380409
ChannelTransport::Unix,
381410
// TODO add MetaTls (T208303369)
@@ -392,7 +421,7 @@ impl ChannelTransport {
392421
/// Returns true if this transport type represents a remote channel.
393422
pub fn is_remote(&self) -> bool {
394423
match self {
395-
ChannelTransport::Tcp => true,
424+
ChannelTransport::Tcp(_) => true,
396425
ChannelTransport::MetaTls(_) => true,
397426
ChannelTransport::Local => false,
398427
ChannelTransport::Sim(_) => false,
@@ -502,18 +531,27 @@ impl ChannelAddr {
502531
/// servers to "any" address.
503532
pub fn any(transport: ChannelTransport) -> Self {
504533
match transport {
505-
ChannelTransport::Tcp => {
506-
let ip = hostname::get()
507-
.ok()
508-
.and_then(|hostname| {
509-
// TODO: Avoid using DNS directly once we figure out a good extensibility story here
510-
hostname.to_str().and_then(|hostname_str| {
511-
dns_lookup::lookup_host(hostname_str)
512-
.ok()
513-
.and_then(|addresses| addresses.first().cloned())
534+
ChannelTransport::Tcp(mode) => {
535+
let ip = match mode {
536+
TcpMode::Localhost => {
537+
// Try IPv6 first, fall back to IPv4 if the system doesn't support IPv6
538+
match std::net::TcpListener::bind((Ipv6Addr::LOCALHOST, 0)) {
539+
Ok(_) => IpAddr::V6(Ipv6Addr::LOCALHOST),
540+
Err(_) => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
541+
}
542+
}
543+
TcpMode::Hostname => hostname::get()
544+
.ok()
545+
.and_then(|hostname| {
546+
// TODO: Avoid using DNS directly once we figure out a good extensibility story here
547+
hostname.to_str().and_then(|hostname_str| {
548+
dns_lookup::lookup_host(hostname_str)
549+
.ok()
550+
.and_then(|addresses| addresses.first().cloned())
551+
})
514552
})
515-
})
516-
.unwrap_or_else(|| IpAddr::from_str("::1").unwrap());
553+
.expect("Failed to resolve hostname to IP address"),
554+
};
517555
Self::Tcp(SocketAddr::new(ip, 0))
518556
}
519557
ChannelTransport::MetaTls(mode) => {
@@ -542,7 +580,13 @@ impl ChannelAddr {
542580
/// The transport used by this address.
543581
pub fn transport(&self) -> ChannelTransport {
544582
match self {
545-
Self::Tcp(_) => ChannelTransport::Tcp,
583+
Self::Tcp(addr) => {
584+
if addr.ip().is_loopback() {
585+
ChannelTransport::Tcp(TcpMode::Localhost)
586+
} else {
587+
ChannelTransport::Tcp(TcpMode::Hostname)
588+
}
589+
}
546590
Self::MetaTls(addr) => match addr {
547591
MetaTlsAddr::Host { hostname, .. } => match hostname.parse::<IpAddr>() {
548592
Ok(IpAddr::V6(_)) => ChannelTransport::MetaTls(TlsMode::IpV6),

hyperactor/src/channel/net.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1799,7 +1799,12 @@ async fn join_nonempty<T: 'static>(set: &mut JoinSet<T>) -> Result<T, JoinError>
17991799
/// Tells whether the address is a 'net' address. These currently have different semantics
18001800
/// from local transports.
18011801
pub fn is_net_addr(addr: &ChannelAddr) -> bool {
1802-
[ChannelTransport::Tcp, ChannelTransport::Unix].contains(&addr.transport())
1802+
match addr.transport() {
1803+
// TODO Metatls?
1804+
ChannelTransport::Tcp(_) => true,
1805+
ChannelTransport::Unix => true,
1806+
_ => false,
1807+
}
18031808
}
18041809

18051810
pub(crate) mod unix {

hyperactor_mesh/src/alloc/remoteprocess.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use hyperactor::channel::ChannelRx;
2626
use hyperactor::channel::ChannelTransport;
2727
use hyperactor::channel::ChannelTx;
2828
use hyperactor::channel::Rx;
29+
use hyperactor::channel::TcpMode;
2930
use hyperactor::channel::Tx;
3031
use hyperactor::channel::TxStatus;
3132
use hyperactor::clock;
@@ -768,7 +769,10 @@ impl RemoteProcessAlloc {
768769
ChannelTransport::MetaTls(_) => {
769770
format!("metatls!{}:{}", host.hostname, self.remote_allocator_port)
770771
}
771-
ChannelTransport::Tcp => {
772+
ChannelTransport::Tcp(TcpMode::Localhost) => {
773+
format!("tcp![::1]:{}", self.remote_allocator_port)
774+
}
775+
ChannelTransport::Tcp(TcpMode::Hostname) => {
772776
format!("tcp!{}:{}", host.hostname, self.remote_allocator_port)
773777
}
774778
// Used only for testing.

hyperactor_mesh/src/bootstrap.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2043,6 +2043,7 @@ mod tests {
20432043
use hyperactor::WorldId;
20442044
use hyperactor::channel::ChannelAddr;
20452045
use hyperactor::channel::ChannelTransport;
2046+
use hyperactor::channel::TcpMode;
20462047
use hyperactor::clock::RealClock;
20472048
use hyperactor::context::Mailbox as _;
20482049
use hyperactor::host::ProcHandle;
@@ -2073,7 +2074,7 @@ mod tests {
20732074
Bootstrap::default(),
20742075
Bootstrap::Proc {
20752076
proc_id: id!(foo[0]),
2076-
backend_addr: ChannelAddr::any(ChannelTransport::Tcp),
2077+
backend_addr: ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
20772078
callback_addr: ChannelAddr::any(ChannelTransport::Unix),
20782079
config: None,
20792080
},

hyperactor_multiprocess/src/proc_actor.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ use hyperactor::actor::Referable;
3434
use hyperactor::actor::remote::Remote;
3535
use hyperactor::channel;
3636
use hyperactor::channel::ChannelAddr;
37+
use hyperactor::channel::TcpMode;
3738
use hyperactor::clock::Clock;
3839
use hyperactor::clock::ClockKind;
3940
use hyperactor::context;
@@ -1376,7 +1377,7 @@ mod tests {
13761377

13771378
// Serve a system.
13781379
let server_handle = System::serve(
1379-
ChannelAddr::any(ChannelTransport::Tcp),
1380+
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
13801381
Duration::from_secs(120),
13811382
Duration::from_secs(120),
13821383
)
@@ -1395,7 +1396,7 @@ mod tests {
13951396
));
13961397

13971398
// Construct a proc forwarder in terms of the system sender.
1398-
let listen_addr = ChannelAddr::any(ChannelTransport::Tcp);
1399+
let listen_addr = ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname));
13991400
let proc_forwarder =
14001401
BoxedMailboxSender::new(DialMailboxRouter::new_with_default(system_sender));
14011402

@@ -1422,7 +1423,7 @@ mod tests {
14221423
let _proc_actor_1 = ProcActor::bootstrap_for_proc(
14231424
proc_1.clone(),
14241425
world_id.clone(),
1425-
ChannelAddr::any(ChannelTransport::Tcp),
1426+
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
14261427
server_handle.local_addr().clone(),
14271428
sup_ref.clone(),
14281429
Duration::from_secs(120),
@@ -1497,7 +1498,7 @@ mod tests {
14971498

14981499
// Serve a system.
14991500
let server_handle = System::serve(
1500-
ChannelAddr::any(ChannelTransport::Tcp),
1501+
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
15011502
Duration::from_secs(120),
15021503
Duration::from_secs(120),
15031504
)
@@ -1518,7 +1519,7 @@ mod tests {
15181519
));
15191520

15201521
// Construct a proc forwarder in terms of the system sender.
1521-
let listen_addr = ChannelAddr::any(ChannelTransport::Tcp);
1522+
let listen_addr = ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname));
15221523
let proc_forwarder =
15231524
BoxedMailboxSender::new(DialMailboxRouter::new_with_default(system_sender));
15241525

@@ -1545,7 +1546,7 @@ mod tests {
15451546
let _proc_actor_1 = ProcActor::bootstrap_for_proc(
15461547
proc_1.clone(),
15471548
world_id.clone(),
1548-
ChannelAddr::any(ChannelTransport::Tcp),
1549+
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
15491550
server_handle.local_addr().clone(),
15501551
sup_ref.clone(),
15511552
Duration::from_secs(120),
@@ -1651,7 +1652,7 @@ mod tests {
16511652
#[tokio::test]
16521653
async fn test_update_address_book_cache() {
16531654
let server_handle = System::serve(
1654-
ChannelAddr::any(ChannelTransport::Tcp),
1655+
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
16551656
Duration::from_secs(2), // supervision update timeout
16561657
Duration::from_secs(2), // duration to evict an unhealthy world
16571658
)
@@ -1705,7 +1706,7 @@ mod tests {
17051706
actor_id: &ActorId,
17061707
system_addr: &ChannelAddr,
17071708
) -> (ActorRef<PingPongActor>, ActorRef<ProcActor>) {
1708-
let listen_addr = ChannelAddr::any(ChannelTransport::Tcp);
1709+
let listen_addr = ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname));
17091710
let bootstrap = ProcActor::bootstrap(
17101711
actor_id.proc_id().clone(),
17111712
actor_id

hyperactor_multiprocess/src/system.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ mod tests {
176176
use hyperactor::WorldId;
177177
use hyperactor::channel::ChannelAddr;
178178
use hyperactor::channel::ChannelTransport;
179+
use hyperactor::channel::TcpMode;
179180
use hyperactor::clock::Clock;
180181
use hyperactor::clock::RealClock;
181182
use hyperactor_telemetry::env::execution_id;
@@ -825,7 +826,7 @@ mod tests {
825826
#[tokio::test]
826827
async fn test_channel_dial_count() {
827828
let system_handle = System::serve(
828-
ChannelAddr::any(ChannelTransport::Tcp),
829+
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
829830
Duration::from_secs(10),
830831
Duration::from_secs(10),
831832
)

hyperactor_multiprocess/src/system_actor.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1849,6 +1849,7 @@ mod tests {
18491849
use hyperactor::channel;
18501850
use hyperactor::channel::ChannelTransport;
18511851
use hyperactor::channel::Rx;
1852+
use hyperactor::channel::TcpMode;
18521853
use hyperactor::clock::Clock;
18531854
use hyperactor::clock::RealClock;
18541855
use hyperactor::data::Serialized;
@@ -2194,7 +2195,7 @@ mod tests {
21942195
// Serve a system. Undeliverable messages encountered by the
21952196
// mailbox server are returned to the system actor.
21962197
let server_handle = System::serve(
2197-
ChannelAddr::any(ChannelTransport::Tcp),
2198+
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
21982199
Duration::from_secs(2), // supervision update timeout
21992200
Duration::from_secs(2), // duration to evict an unhealthy world
22002201
)
@@ -2255,7 +2256,7 @@ mod tests {
22552256
let _proc_actor_0 = ProcActor::bootstrap_for_proc(
22562257
proc_0.clone(),
22572258
world_id.clone(),
2258-
ChannelAddr::any(ChannelTransport::Tcp),
2259+
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
22592260
server_handle.local_addr().clone(),
22602261
sup_ref.clone(),
22612262
Duration::from_millis(300), // supervision update interval
@@ -2272,7 +2273,7 @@ mod tests {
22722273
let proc_actor_1 = ProcActor::bootstrap_for_proc(
22732274
proc_1.clone(),
22742275
world_id.clone(),
2275-
ChannelAddr::any(ChannelTransport::Tcp),
2276+
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
22762277
server_handle.local_addr().clone(),
22772278
sup_ref.clone(),
22782279
Duration::from_millis(300), // supervision update interval
@@ -2348,7 +2349,7 @@ mod tests {
23482349
#[tokio::test]
23492350
async fn test_stop_fast() -> Result<()> {
23502351
let server_handle = System::serve(
2351-
ChannelAddr::any(ChannelTransport::Tcp),
2352+
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
23522353
Duration::from_secs(2), // supervision update timeout
23532354
Duration::from_secs(2), // duration to evict an unhealthy world
23542355
)

0 commit comments

Comments
 (0)