Skip to content

Commit 8691bcf

Browse files
Robert Rusch authored and facebook-github-bot committed
Fix up failing tests by adding Tcp::Localhost mode (meta-pytorch#1479)
Summary: This diff adds support for TCP localhost to fix T232884876. I go into detail in https://fb.workplace.com/groups/1434680487636162/posts/1441470780290466/, but basically, the hostname can't be bound to in certain environments. This generally takes the approach discussed with mariusae, though I ended up doing localhost instead of unspecified. This is because the tests were already set up with localhost; it was just being ignored. This is a bit of a short-term fix, as I hope to follow up by allowing the hostname to get sent to workers while using unspecified locally. DNS is "more correctly" resolved on the worker machines, due to proxying and load balancing questions, but that's a bigger lift. Reviewed By: mariusae Differential Revision: D84094699
1 parent 138ad1f commit 8691bcf

File tree

14 files changed

+114
-50
lines changed

14 files changed

+114
-50
lines changed

hyper/src/commands/serve.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use std::time::Duration;
1010

1111
use hyperactor::channel::ChannelAddr;
1212
use hyperactor::channel::ChannelTransport;
13+
use hyperactor::channel::TcpMode;
1314
use hyperactor_multiprocess::system::System;
1415

1516
// The commands in the demo spawn temporary actors that join a system.
@@ -27,7 +28,9 @@ pub struct ServeCommand {
2728

2829
impl ServeCommand {
2930
pub async fn run(self) -> anyhow::Result<()> {
30-
let addr = self.addr.unwrap_or(ChannelAddr::any(ChannelTransport::Tcp));
31+
let addr = self
32+
.addr
33+
.unwrap_or(ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)));
3134
let handle = System::serve(addr, LONG_DURATION, LONG_DURATION).await?;
3235
eprintln!("serve: {}", handle.local_addr());
3336
handle.await;

hyperactor/benches/main.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use hyperactor::channel;
2222
use hyperactor::channel::ChannelAddr;
2323
use hyperactor::channel::ChannelTransport;
2424
use hyperactor::channel::Rx;
25+
use hyperactor::channel::TcpMode;
2526
use hyperactor::channel::Tx;
2627
use hyperactor::channel::dial;
2728
use hyperactor::channel::serve;
@@ -66,7 +67,7 @@ impl Message {
6667
fn bench_message_sizes(c: &mut Criterion) {
6768
let transports = vec![
6869
("local", ChannelTransport::Local),
69-
("tcp", ChannelTransport::Tcp),
70+
("tcp", ChannelTransport::Tcp(TcpMode::Hostname)),
7071
("unix", ChannelTransport::Unix),
7172
];
7273

@@ -108,7 +109,7 @@ fn bench_message_rates(c: &mut Criterion) {
108109

109110
let transports = vec![
110111
("local", ChannelTransport::Local),
111-
("tcp", ChannelTransport::Tcp),
112+
("tcp", ChannelTransport::Tcp(TcpMode::Hostname)),
112113
("unix", ChannelTransport::Unix),
113114
//TODO Add TLS once it is able to run in Sandcastle
114115
];

hyperactor/src/channel.rs

Lines changed: 55 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
use core::net::SocketAddr;
1515
use std::fmt;
1616
use std::net::IpAddr;
17+
use std::net::Ipv6Addr;
1718
#[cfg(target_os = "linux")]
1819
use std::os::linux::net::SocketAddrExt;
1920
use std::str::FromStr;
@@ -236,6 +237,26 @@ impl<M: RemoteMessage> Rx<M> for MpscRx<M> {
236237
}
237238
}
238239

240+
/// The addressing mode to use for TCP connections.
241+
#[derive(
242+
Clone,
243+
Debug,
244+
PartialEq,
245+
Eq,
246+
Hash,
247+
Serialize,
248+
Deserialize,
249+
strum::EnumIter,
250+
strum::Display,
251+
strum::EnumString
252+
)]
253+
pub enum TcpMode {
254+
/// Use localhost/loopback for the connection.
255+
Localhost,
256+
/// Use host domain name for the connection.
257+
Hostname,
258+
}
259+
239260
/// The hostname to use for TLS connections.
240261
#[derive(
241262
Clone,
@@ -315,7 +336,7 @@ impl fmt::Display for MetaTlsAddr {
315336
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Named)]
316337
pub enum ChannelTransport {
317338
/// Transport over a TCP connection.
318-
Tcp,
339+
Tcp(TcpMode),
319340

320341
/// Transport over a TCP connection with TLS support within Meta
321342
MetaTls(TlsMode),
@@ -333,7 +354,7 @@ pub enum ChannelTransport {
333354
impl fmt::Display for ChannelTransport {
334355
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
335356
match self {
336-
Self::Tcp => write!(f, "tcp"),
357+
Self::Tcp(mode) => write!(f, "tcp({:?})", mode),
337358
Self::MetaTls(mode) => write!(f, "metatls({:?})", mode),
338359
Self::Local => write!(f, "local"),
339360
Self::Sim(transport) => write!(f, "sim({})", transport),
@@ -358,7 +379,13 @@ impl FromStr for ChannelTransport {
358379
}
359380

360381
match s {
361-
"tcp" => Ok(ChannelTransport::Tcp),
382+
// Default to TcpMode::Hostname, if the mode isn't set
383+
"tcp" => Ok(ChannelTransport::Tcp(TcpMode::Hostname)),
384+
s if s.starts_with("tcp(") => {
385+
let inner = &s["tcp(".len()..s.len() - 1];
386+
let mode = inner.parse()?;
387+
Ok(ChannelTransport::Tcp(mode))
388+
}
362389
"local" => Ok(ChannelTransport::Local),
363390
"unix" => Ok(ChannelTransport::Unix),
364391
s if s.starts_with("metatls(") && s.ends_with(")") => {
@@ -373,9 +400,10 @@ impl FromStr for ChannelTransport {
373400

374401
impl ChannelTransport {
375402
/// All known channel transports.
376-
pub fn all() -> [ChannelTransport; 3] {
403+
pub fn all() -> [ChannelTransport; 4] {
377404
[
378-
ChannelTransport::Tcp,
405+
ChannelTransport::Tcp(TcpMode::Localhost),
406+
ChannelTransport::Tcp(TcpMode::Hostname),
379407
ChannelTransport::Local,
380408
ChannelTransport::Unix,
381409
// TODO add MetaTls (T208303369)
@@ -392,7 +420,7 @@ impl ChannelTransport {
392420
/// Returns true if this transport type represents a remote channel.
393421
pub fn is_remote(&self) -> bool {
394422
match self {
395-
ChannelTransport::Tcp => true,
423+
ChannelTransport::Tcp(_) => true,
396424
ChannelTransport::MetaTls(_) => true,
397425
ChannelTransport::Local => false,
398426
ChannelTransport::Sim(_) => false,
@@ -502,18 +530,21 @@ impl ChannelAddr {
502530
/// servers to "any" address.
503531
pub fn any(transport: ChannelTransport) -> Self {
504532
match transport {
505-
ChannelTransport::Tcp => {
506-
let ip = hostname::get()
507-
.ok()
508-
.and_then(|hostname| {
509-
// TODO: Avoid using DNS directly once we figure out a good extensibility story here
510-
hostname.to_str().and_then(|hostname_str| {
511-
dns_lookup::lookup_host(hostname_str)
512-
.ok()
513-
.and_then(|addresses| addresses.first().cloned())
533+
ChannelTransport::Tcp(mode) => {
534+
let ip = match mode {
535+
TcpMode::Localhost => IpAddr::V6(Ipv6Addr::LOCALHOST),
536+
TcpMode::Hostname => hostname::get()
537+
.ok()
538+
.and_then(|hostname| {
539+
// TODO: Avoid using DNS directly once we figure out a good extensibility story here
540+
hostname.to_str().and_then(|hostname_str| {
541+
dns_lookup::lookup_host(hostname_str)
542+
.ok()
543+
.and_then(|addresses| addresses.first().cloned())
544+
})
514545
})
515-
})
516-
.unwrap_or_else(|| IpAddr::from_str("::1").unwrap());
546+
.expect("Failed to resolve hostname to IP address"),
547+
};
517548
Self::Tcp(SocketAddr::new(ip, 0))
518549
}
519550
ChannelTransport::MetaTls(mode) => {
@@ -542,7 +573,13 @@ impl ChannelAddr {
542573
/// The transport used by this address.
543574
pub fn transport(&self) -> ChannelTransport {
544575
match self {
545-
Self::Tcp(_) => ChannelTransport::Tcp,
576+
Self::Tcp(addr) => {
577+
if addr.ip().is_loopback() {
578+
ChannelTransport::Tcp(TcpMode::Localhost)
579+
} else {
580+
ChannelTransport::Tcp(TcpMode::Hostname)
581+
}
582+
}
546583
Self::MetaTls(addr) => match addr {
547584
MetaTlsAddr::Host { hostname, .. } => match hostname.parse::<IpAddr>() {
548585
Ok(IpAddr::V6(_)) => ChannelTransport::MetaTls(TlsMode::IpV6),

hyperactor/src/channel/net.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1799,7 +1799,12 @@ async fn join_nonempty<T: 'static>(set: &mut JoinSet<T>) -> Result<T, JoinError>
17991799
/// Tells whether the address is a 'net' address. These currently have different semantics
18001800
/// from local transports.
18011801
pub fn is_net_addr(addr: &ChannelAddr) -> bool {
1802-
[ChannelTransport::Tcp, ChannelTransport::Unix].contains(&addr.transport())
1802+
match addr.transport() {
1803+
// TODO Metatls?
1804+
ChannelTransport::Tcp(_) => true,
1805+
ChannelTransport::Unix => true,
1806+
_ => false,
1807+
}
18031808
}
18041809

18051810
pub(crate) mod unix {

hyperactor_mesh/src/alloc/remoteprocess.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use hyperactor::channel::ChannelRx;
2626
use hyperactor::channel::ChannelTransport;
2727
use hyperactor::channel::ChannelTx;
2828
use hyperactor::channel::Rx;
29+
use hyperactor::channel::TcpMode;
2930
use hyperactor::channel::Tx;
3031
use hyperactor::channel::TxStatus;
3132
use hyperactor::clock;
@@ -768,7 +769,10 @@ impl RemoteProcessAlloc {
768769
ChannelTransport::MetaTls(_) => {
769770
format!("metatls!{}:{}", host.hostname, self.remote_allocator_port)
770771
}
771-
ChannelTransport::Tcp => {
772+
ChannelTransport::Tcp(TcpMode::Localhost) => {
773+
format!("tcp![::1]:{}", self.remote_allocator_port)
774+
}
775+
ChannelTransport::Tcp(TcpMode::Hostname) => {
772776
format!("tcp!{}:{}", host.hostname, self.remote_allocator_port)
773777
}
774778
// Used only for testing.

hyperactor_mesh/src/bootstrap.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2030,6 +2030,7 @@ mod tests {
20302030
use hyperactor::WorldId;
20312031
use hyperactor::channel::ChannelAddr;
20322032
use hyperactor::channel::ChannelTransport;
2033+
use hyperactor::channel::TcpMode;
20332034
use hyperactor::clock::RealClock;
20342035
use hyperactor::context::Mailbox as _;
20352036
use hyperactor::host::ProcHandle;
@@ -2060,7 +2061,7 @@ mod tests {
20602061
Bootstrap::default(),
20612062
Bootstrap::Proc {
20622063
proc_id: id!(foo[0]),
2063-
backend_addr: ChannelAddr::any(ChannelTransport::Tcp),
2064+
backend_addr: ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
20642065
callback_addr: ChannelAddr::any(ChannelTransport::Unix),
20652066
config: None,
20662067
},

hyperactor_multiprocess/src/proc_actor.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ use hyperactor::actor::Referable;
3434
use hyperactor::actor::remote::Remote;
3535
use hyperactor::channel;
3636
use hyperactor::channel::ChannelAddr;
37+
use hyperactor::channel::TcpMode;
3738
use hyperactor::clock::Clock;
3839
use hyperactor::clock::ClockKind;
3940
use hyperactor::context;
@@ -1376,7 +1377,7 @@ mod tests {
13761377

13771378
// Serve a system.
13781379
let server_handle = System::serve(
1379-
ChannelAddr::any(ChannelTransport::Tcp),
1380+
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
13801381
Duration::from_secs(120),
13811382
Duration::from_secs(120),
13821383
)
@@ -1395,7 +1396,7 @@ mod tests {
13951396
));
13961397

13971398
// Construct a proc forwarder in terms of the system sender.
1398-
let listen_addr = ChannelAddr::any(ChannelTransport::Tcp);
1399+
let listen_addr = ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname));
13991400
let proc_forwarder =
14001401
BoxedMailboxSender::new(DialMailboxRouter::new_with_default(system_sender));
14011402

@@ -1422,7 +1423,7 @@ mod tests {
14221423
let _proc_actor_1 = ProcActor::bootstrap_for_proc(
14231424
proc_1.clone(),
14241425
world_id.clone(),
1425-
ChannelAddr::any(ChannelTransport::Tcp),
1426+
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
14261427
server_handle.local_addr().clone(),
14271428
sup_ref.clone(),
14281429
Duration::from_secs(120),
@@ -1497,7 +1498,7 @@ mod tests {
14971498

14981499
// Serve a system.
14991500
let server_handle = System::serve(
1500-
ChannelAddr::any(ChannelTransport::Tcp),
1501+
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
15011502
Duration::from_secs(120),
15021503
Duration::from_secs(120),
15031504
)
@@ -1518,7 +1519,7 @@ mod tests {
15181519
));
15191520

15201521
// Construct a proc forwarder in terms of the system sender.
1521-
let listen_addr = ChannelAddr::any(ChannelTransport::Tcp);
1522+
let listen_addr = ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname));
15221523
let proc_forwarder =
15231524
BoxedMailboxSender::new(DialMailboxRouter::new_with_default(system_sender));
15241525

@@ -1545,7 +1546,7 @@ mod tests {
15451546
let _proc_actor_1 = ProcActor::bootstrap_for_proc(
15461547
proc_1.clone(),
15471548
world_id.clone(),
1548-
ChannelAddr::any(ChannelTransport::Tcp),
1549+
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
15491550
server_handle.local_addr().clone(),
15501551
sup_ref.clone(),
15511552
Duration::from_secs(120),
@@ -1651,7 +1652,7 @@ mod tests {
16511652
#[tokio::test]
16521653
async fn test_update_address_book_cache() {
16531654
let server_handle = System::serve(
1654-
ChannelAddr::any(ChannelTransport::Tcp),
1655+
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
16551656
Duration::from_secs(2), // supervision update timeout
16561657
Duration::from_secs(2), // duration to evict an unhealthy world
16571658
)
@@ -1705,7 +1706,7 @@ mod tests {
17051706
actor_id: &ActorId,
17061707
system_addr: &ChannelAddr,
17071708
) -> (ActorRef<PingPongActor>, ActorRef<ProcActor>) {
1708-
let listen_addr = ChannelAddr::any(ChannelTransport::Tcp);
1709+
let listen_addr = ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname));
17091710
let bootstrap = ProcActor::bootstrap(
17101711
actor_id.proc_id().clone(),
17111712
actor_id

hyperactor_multiprocess/src/system.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ mod tests {
176176
use hyperactor::WorldId;
177177
use hyperactor::channel::ChannelAddr;
178178
use hyperactor::channel::ChannelTransport;
179+
use hyperactor::channel::TcpMode;
179180
use hyperactor::clock::Clock;
180181
use hyperactor::clock::RealClock;
181182
use hyperactor_telemetry::env::execution_id;
@@ -825,7 +826,7 @@ mod tests {
825826
#[tokio::test]
826827
async fn test_channel_dial_count() {
827828
let system_handle = System::serve(
828-
ChannelAddr::any(ChannelTransport::Tcp),
829+
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
829830
Duration::from_secs(10),
830831
Duration::from_secs(10),
831832
)

hyperactor_multiprocess/src/system_actor.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1849,6 +1849,7 @@ mod tests {
18491849
use hyperactor::channel;
18501850
use hyperactor::channel::ChannelTransport;
18511851
use hyperactor::channel::Rx;
1852+
use hyperactor::channel::TcpMode;
18521853
use hyperactor::clock::Clock;
18531854
use hyperactor::clock::RealClock;
18541855
use hyperactor::data::Serialized;
@@ -2194,7 +2195,7 @@ mod tests {
21942195
// Serve a system. Undeliverable messages encountered by the
21952196
// mailbox server are returned to the system actor.
21962197
let server_handle = System::serve(
2197-
ChannelAddr::any(ChannelTransport::Tcp),
2198+
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
21982199
Duration::from_secs(2), // supervision update timeout
21992200
Duration::from_secs(2), // duration to evict an unhealthy world
22002201
)
@@ -2255,7 +2256,7 @@ mod tests {
22552256
let _proc_actor_0 = ProcActor::bootstrap_for_proc(
22562257
proc_0.clone(),
22572258
world_id.clone(),
2258-
ChannelAddr::any(ChannelTransport::Tcp),
2259+
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
22592260
server_handle.local_addr().clone(),
22602261
sup_ref.clone(),
22612262
Duration::from_millis(300), // supervision update interval
@@ -2272,7 +2273,7 @@ mod tests {
22722273
let proc_actor_1 = ProcActor::bootstrap_for_proc(
22732274
proc_1.clone(),
22742275
world_id.clone(),
2275-
ChannelAddr::any(ChannelTransport::Tcp),
2276+
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
22762277
server_handle.local_addr().clone(),
22772278
sup_ref.clone(),
22782279
Duration::from_millis(300), // supervision update interval
@@ -2348,7 +2349,7 @@ mod tests {
23482349
#[tokio::test]
23492350
async fn test_stop_fast() -> Result<()> {
23502351
let server_handle = System::serve(
2351-
ChannelAddr::any(ChannelTransport::Tcp),
2352+
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
23522353
Duration::from_secs(2), // supervision update timeout
23532354
Duration::from_secs(2), // duration to evict an unhealthy world
23542355
)

0 commit comments

Comments
 (0)