Skip to content

Commit c2d0771

Browse files
committed
Introduce multiple addresses for connection
1 parent 91cc177 commit c2d0771

File tree

2 files changed

+74
-28
lines changed

2 files changed

+74
-28
lines changed

src/cli.rs

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,28 @@ pub struct Args {
2424
#[arg(long, value_name = "path to script file")]
2525
pub script: Vec<String>,
2626

27-
/// TypeDB address to connect to. If using TLS encryption, this must start with "https://"
28-
#[arg(long, value_name = ADDRESS_VALUE_NAME)]
27+
/// TypeDB address to connect to (host:port). If using TLS encryption, this must start with "https://".
28+
#[arg(long, value_name = ADDRESS_VALUE_NAME, conflicts_with_all = ["addresses", "address_translation"])]
2929
pub address: Option<String>,
3030

31+
/// A comma-separated list of TypeDB replica addresses of a single cluster to connect to.
32+
#[arg(long, value_name = "host1:port1,host2:port2", conflicts_with_all = ["address", "address_translation"])]
33+
pub addresses: Option<String>,
34+
35+
/// A comma-separated list of public=private address pairs. Public addresses are the user-facing
36+
/// addresses of the replicas, and private addresses are the addresses used for the server-side
37+
/// connection between replicas.
38+
#[arg(long, value_name = "public=private,...", conflicts_with_all = ["address", "addresses"])]
39+
pub address_translation: Option<String>,
40+
41+
/// If used in a cluster environment, disables attempts to redirect requests to server replicas,
42+
/// limiting Console to communicate only with the single address specified in the `address`
43+
/// argument.
44+
/// Use for administrative / debug purposes to test a specific replica only: this option will
45+
/// lower the success rate of Console's operations in production.
46+
#[arg(long = "replication-disabled", default_value = "false")]
47+
pub replication_disabled: bool,
48+
3149
/// Username for authentication
3250
#[arg(long, value_name = USERNAME_VALUE_NAME)]
3351
pub username: Option<String>,
@@ -48,8 +66,8 @@ pub struct Args {
4866

4967
/// Disable error reporting. Error reporting helps TypeDB improve by reporting
5068
/// errors and crashes to the development team.
51-
#[arg(long = "diagnostics-disable", default_value = "false")]
52-
pub diagnostics_disable: bool,
69+
#[arg(long = "diagnostics-disabled", default_value = "false")]
70+
pub diagnostics_disabled: bool,
5371

5472
/// Print the Console binary version
5573
#[arg(long = "version")]

src/main.rs

Lines changed: 52 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use std::{
1717
rc::Rc,
1818
sync::Arc,
1919
};
20-
20+
use std::collections::HashMap;
2121
use clap::Parser;
2222
use home::home_dir;
2323
use rustyline::error::ReadlineError;
@@ -126,13 +126,19 @@ fn main() {
126126
println!("{}", VERSION);
127127
exit(ExitCode::Success as i32);
128128
}
129-
let address = match args.address {
130-
Some(address) => address,
131-
None => {
132-
println_error!("missing server address ('{}').", format_argument!("--address <{ADDRESS_VALUE_NAME}>"));
133-
exit(ExitCode::UserInputError as i32);
134-
}
135-
};
129+
let address_info = parse_addresses(&args);
130+
if !args.tls_disabled && !address_info.only_https {
131+
println_error!(
132+
"\
133+
TLS connections can only be enabled when connecting to HTTPS endpoints. \
134+
For example, using 'https://<ip>:port'.\n\
135+
Please modify the address or disable TLS ('{}'). {}\
136+
",
137+
format_argument!("--tls-disabled"),
138+
format_warning!("WARNING: this will send passwords over plaintext!"),
139+
);
140+
exit(ExitCode::UserInputError as i32);
141+
}
136142
let username = match args.username {
137143
Some(username) => username,
138144
None => {
@@ -146,28 +152,16 @@ fn main() {
146152
if args.password.is_none() {
147153
args.password = Some(LineReaderHidden::new().readline(&format!("password for '{username}': ")));
148154
}
149-
if !args.diagnostics_disable {
155+
if !args.diagnostics_disabled {
150156
init_diagnostics()
151157
}
152-
if !args.tls_disabled && !address.starts_with("https:") {
153-
println_error!(
154-
"\
155-
TLS connections can only be enabled when connecting to HTTPS endpoints. \
156-
For example, using 'https://<ip>:port'.\n\
157-
Please modify the address or disable TLS ('{}'). {}\
158-
",
159-
format_argument!("--tls-disabled"),
160-
format_warning!("WARNING: this will send passwords over plaintext!"),
161-
);
162-
exit(ExitCode::UserInputError as i32);
163-
}
164158
let tls_root_ca_path = args.tls_root_ca.as_ref().map(|value| Path::new(value));
165-
166159
let runtime = BackgroundRuntime::new();
160+
let driver_options = DriverOptions::new().use_replication(!args.replication_disabled).tls_enabled(!args.tls_disabled).tls_root_ca(tls_root_ca_path).unwrap();
167161
let driver = match runtime.run(TypeDBDriver::new(
168-
Addresses::try_from_address_str(&address).unwrap(),
162+
address_info.addresses,
169163
Credentials::new(&username, args.password.as_ref().unwrap()),
170-
DriverOptions::new().tls_enabled(!args.tls_disabled).tls_root_ca(tls_root_ca_path).unwrap(),
164+
driver_options,
171165
)) {
172166
Ok(driver) => Arc::new(driver),
173167
Err(err) => {
@@ -485,6 +479,40 @@ fn transaction_type_str(transaction_type: TransactionType) -> &'static str {
485479
}
486480
}
487481

482+
struct AddressInfo {
483+
only_https: bool,
484+
addresses: Addresses,
485+
}
486+
487+
fn parse_addresses(args: &Args) -> AddressInfo {
488+
if let Some(address) = &args.address {
489+
AddressInfo {only_https: is_https_address(address), addresses: Addresses::try_from_address_str(address).unwrap() }
490+
} else if let Some(addresses) = &args.addresses {
491+
let split = addresses.split(',').map(str::to_string).collect::<Vec<_>>();
492+
println!("Split: {split:?}");
493+
let only_https = split.iter().all(|address| is_https_address(address));
494+
AddressInfo {only_https, addresses: Addresses::try_from_addresses_str(split).unwrap() }
495+
} else if let Some(translation) = &args.address_translation {
496+
let mut map = HashMap::new();
497+
let mut only_https = true;
498+
for pair in translation.split(',') {
499+
let (public_address, private_address) = pair
500+
.split_once('=')
501+
.unwrap_or_else(|| panic!("Invalid address pair: {pair}. Must be of form public=private"));
502+
only_https = only_https && is_https_address(public_address);
503+
map.insert(public_address.to_string(), private_address.to_string());
504+
}
505+
println!("Translation map:: {map:?}");
506+
AddressInfo {only_https, addresses: Addresses::try_from_translation_str(map).unwrap() }
507+
} else {
508+
panic!("At least one of --address, --addresses, or --address-translation must be provided.");
509+
}
510+
}
511+
512+
fn is_https_address(address: &str) -> bool {
513+
address.starts_with("https:")
514+
}
515+
488516
fn init_diagnostics() {
489517
let _ = sentry::init((
490518
DIAGNOSTICS_REPORTING_URI,

0 commit comments

Comments
 (0)