Skip to content

Commit 1c72ac0

Browse files
committed
Introduce multiple addresses for connection
1 parent f4b59c4 commit 1c72ac0

File tree

5 files changed

+72
-14
lines changed

5 files changed

+72
-14
lines changed

Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ features = {}
4646

4747
[dependencies.typedb-driver]
4848
features = []
49-
rev = "59547fa99650030c449697b74fbc30ce63264104"
49+
rev = "f1bf3d9e327344fe7a67084c13da4eb0fb5ca2be"
5050
git = "https://github.com/typedb/typedb-driver"
5151
default-features = false
5252

dependencies/typedb/repositories.bzl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,14 @@ def typedb_dependencies():
1313

1414
def typedb_driver():
1515
# TODO: Return typedb
16+
# native.local_repository(
17+
# name = "typedb_driver",
18+
# path = "../typedb-driver",
19+
# )
1620
git_repository(
1721
name = "typedb_driver",
1822
remote = "https://github.com/farost/typedb-driver",
19-
commit = "59547fa99650030c449697b74fbc30ce63264104", # sync-marker: do not remove this comment, this is used for sync-dependencies by @typedb_driver
23+
commit = "f1bf3d9e327344fe7a67084c13da4eb0fb5ca2be", # sync-marker: do not remove this comment, this is used for sync-dependencies by @typedb_driver
2024
)
2125
# git_repository(
2226
# name = "typedb_driver",

src/cli.rs

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,27 @@ pub struct Args {
2121
#[arg(long, value_name = "path to script file")]
2222
pub script: Vec<String>,
2323

24-
/// TypeDB address to connect to. If using TLS encryption, this must start with "https://"
25-
#[arg(long, value_name = "host:port")]
26-
pub address: String,
24+
/// TypeDB address to connect to (host:port). If using TLS encryption, this must start with "https://".
25+
#[arg(long, value_name = "host:port", conflicts_with_all = ["addresses", "address_translation"])]
26+
pub address: Option<String>,
27+
28+
/// A comma-separated list of TypeDB replica addresses of a single cluster to connect to.
29+
#[arg(long, value_name = "host1:port1,host2:port2", conflicts_with_all = ["address", "address_translation"])]
30+
pub addresses: Option<String>,
31+
32+
/// A comma-separated list of public=private address pairs. Public addresses are the user-facing
33+
/// addresses of the replicas, and private addresses are the addresses used for the server-side
34+
/// connection between replicas.
35+
#[arg(long, value_name = "public=private,...", conflicts_with_all = ["address", "addresses"])]
36+
pub address_translation: Option<String>,
37+
38+
/// If used in a cluster environment, disables attempts to redirect requests to server replicas,
39+
/// limiting Console to communicate only with the single address specified in the `address`
40+
/// argument.
41+
/// Use for administrative / debug purposes to test a specific replica only: this option will
42+
/// lower the success rate of Console's operations in production.
43+
#[arg(long = "replication-disabled", default_value = "false")]
44+
pub replication_disabled: bool,
2745

2846
/// Username for authentication
2947
#[arg(long, value_name = "username")]
@@ -45,8 +63,8 @@ pub struct Args {
4563

4664
/// Disable error reporting. Error reporting helps TypeDB improve by reporting
4765
/// errors and crashes to the development team.
48-
#[arg(long = "diagnostics-disable", default_value = "false")]
49-
pub diagnostics_disable: bool,
66+
#[arg(long = "diagnostics-disabled", default_value = "false")]
67+
pub diagnostics_disabled: bool,
5068

5169
/// Print the Console binary version
5270
#[arg(long = "version")]

src/main.rs

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use std::{
1717
rc::Rc,
1818
sync::Arc,
1919
};
20-
20+
use std::collections::HashMap;
2121
use clap::Parser;
2222
use home::home_dir;
2323
use rustyline::error::ReadlineError;
@@ -99,10 +99,11 @@ fn main() {
9999
if args.password.is_none() {
100100
args.password = Some(LineReaderHidden::new().readline(&format!("password for '{}': ", args.username)));
101101
}
102-
if !args.diagnostics_disable {
102+
if !args.diagnostics_disabled {
103103
init_diagnostics()
104104
}
105-
if !args.tls_disabled && !args.address.starts_with("https:") {
105+
let address_info = parse_addresses(&args);
106+
if !args.tls_disabled && !address_info.only_https {
106107
println!(
107108
"\
108109
TLS connections can only be enabled when connecting to HTTPS endpoints, for example using 'https://<ip>:port'. \
@@ -113,10 +114,11 @@ fn main() {
113114
}
114115
let runtime = BackgroundRuntime::new();
115116
let tls_root_ca_path = args.tls_root_ca.as_ref().map(|value| Path::new(value));
117+
let driver_options = DriverOptions::new().use_replication(!args.replication_disabled).is_tls_enabled(!args.tls_disabled).tls_root_ca(tls_root_ca_path).unwrap();
116118
let driver = match runtime.run(TypeDBDriver::new(
117-
Addresses::try_from_address_str(args.address).unwrap(),
119+
address_info.addresses,
118120
Credentials::new(&args.username, args.password.as_ref().unwrap()),
119-
DriverOptions::new().is_tls_enabled(!args.tls_disabled).tls_root_ca(tls_root_ca_path).unwrap(),
121+
driver_options,
120122
)) {
121123
Ok(driver) => Arc::new(driver),
122124
Err(err) => {
@@ -425,6 +427,40 @@ fn transaction_type_str(transaction_type: TransactionType) -> &'static str {
425427
}
426428
}
427429

430+
struct AddressInfo {
431+
only_https: bool,
432+
addresses: Addresses,
433+
}
434+
435+
fn parse_addresses(args: &Args) -> AddressInfo {
436+
if let Some(address) = &args.address {
437+
AddressInfo {only_https: is_https_address(address), addresses: Addresses::try_from_address_str(address).unwrap() }
438+
} else if let Some(addresses) = &args.addresses {
439+
let split = addresses.split(',').map(str::to_string).collect::<Vec<_>>();
440+
println!("Split: {split:?}");
441+
let only_https = split.iter().all(|address| is_https_address(address));
442+
AddressInfo {only_https, addresses: Addresses::try_from_addresses_str(split).unwrap() }
443+
} else if let Some(translation) = &args.address_translation {
444+
let mut map = HashMap::new();
445+
let mut only_https = true;
446+
for pair in translation.split(',') {
447+
let (public_address, private_address) = pair
448+
.split_once('=')
449+
.unwrap_or_else(|| panic!("Invalid address pair: {pair}. Must be of form public=private"));
450+
only_https = only_https && is_https_address(public_address);
451+
map.insert(public_address.to_string(), private_address.to_string());
452+
}
453+
println!("Translation map:: {map:?}");
454+
AddressInfo {only_https, addresses: Addresses::try_from_translation_str(map).unwrap() }
455+
} else {
456+
panic!("At least one of --address, --addresses, or --address-translation must be provided.");
457+
}
458+
}
459+
460+
fn is_https_address(address: &str) -> bool {
461+
address.starts_with("https:")
462+
}
463+
428464
fn init_diagnostics() {
429465
let _ = sentry::init((
430466
DIAGNOSTICS_REPORTING_URI,

0 commit comments

Comments
 (0)