Skip to content

Commit 4357f3b

Browse files
authored
Persist ohttp keys (payjoin#616)
2 parents b2aba5e + c26bea6 commit 4357f3b

File tree

8 files changed

+140
-23
lines changed

8 files changed

+140
-23
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ target
55
Cargo.lock
66
.vscode
77
mutants.out*
8+
*.ikm

Cargo-minimal.lock

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,6 +1665,7 @@ dependencies = [
16651665
"payjoin",
16661666
"redis",
16671667
"rustls 0.22.4",
1668+
"tempfile",
16681669
"tokio",
16691670
"tokio-rustls",
16701671
"tracing",

Cargo-recent.lock

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,6 +1665,7 @@ dependencies = [
16651665
"payjoin",
16661666
"redis",
16671667
"rustls 0.22.4",
1668+
"tempfile",
16681669
"tokio",
16691670
"tokio-rustls",
16701671
"tracing",

payjoin-directory/Cargo.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ _danger-local-https = ["hyper-rustls", "rustls", "tokio-rustls"]
1818

1919
[dependencies]
2020
anyhow = "1.0.71"
21-
bitcoin = { version = "0.32.4", features = ["base64"] }
21+
bitcoin = { version = "0.32.4", features = ["base64", "rand-std"] }
2222
bhttp = { version = "=0.5.1", features = ["http"] }
2323
futures = "0.3.17"
2424
http-body-util = "0.1.2"
@@ -33,3 +33,6 @@ tokio = { version = "1.12.0", features = ["full"] }
3333
tokio-rustls = { version = "0.25", features = ["ring"], default-features = false, optional = true }
3434
tracing = "0.1.37"
3535
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
36+
37+
[dev-dependencies]
38+
tempfile = "3.5.0"
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
//! Manage the OHTTP key configuration
2+
3+
use std::fs;
4+
use std::path::{Path, PathBuf};
5+
6+
use anyhow::{anyhow, Result};
7+
use ohttp::hpke::{Aead, Kdf, Kem};
8+
use ohttp::SymmetricSuite;
9+
use tracing::info;
10+
11+
const KEY_ID: u8 = 1;
12+
const KEM: Kem = Kem::K256Sha256;
13+
const SYMMETRIC: &[SymmetricSuite] =
14+
&[SymmetricSuite::new(Kdf::HkdfSha256, Aead::ChaCha20Poly1305)];
15+
16+
/// OHTTP server key configuration
17+
///
18+
/// This is combined so that the test path and the prod path both use the same
19+
/// code. The ServerKeyConfig.ikm is persisted to the configured path, and the
20+
/// server is used to run the directory server.
21+
#[derive(Debug, Clone)]
22+
pub struct ServerKeyConfig {
23+
ikm: [u8; 32],
24+
server: ohttp::Server,
25+
}
26+
27+
impl From<ServerKeyConfig> for ohttp::Server {
28+
fn from(value: ServerKeyConfig) -> Self { value.server }
29+
}
30+
31+
/// Generate a new OHTTP server key configuration
32+
pub fn gen_ohttp_server_config() -> Result<ServerKeyConfig> {
33+
let ikm = bitcoin::key::rand::random::<[u8; 32]>();
34+
let config = ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC))?;
35+
Ok(ServerKeyConfig { ikm, server: ohttp::Server::new(config)? })
36+
}
37+
38+
/// Persist an OHTTP Key Configuration to the default path
39+
pub fn persist_new_key_config(ohttp_config: ServerKeyConfig, dir: &Path) -> Result<PathBuf> {
40+
use std::fs::OpenOptions;
41+
use std::io::Write;
42+
43+
let key_path = key_path(dir);
44+
45+
let mut file = OpenOptions::new()
46+
.write(true)
47+
.create_new(true)
48+
.open(&key_path)
49+
.map_err(|e| anyhow!("Failed to create new OHTTP key file: {}", e))?;
50+
51+
file.write_all(&ohttp_config.ikm)
52+
.map_err(|e| anyhow!("Failed to write OHTTP keys to file: {}", e))?;
53+
info!("Saved OHTTP Key Configuration to {}", &key_path.display());
54+
55+
Ok(key_path)
56+
}
57+
58+
/// Read the configured server from the default path
59+
/// May panic if key exists but is the unexpected format.
60+
pub fn read_server_config(dir: &Path) -> Result<ServerKeyConfig> {
61+
let key_path = key_path(dir);
62+
let ikm: [u8; 32] = fs::read(&key_path)
63+
.map_err(|e| anyhow!("Failed to read OHTTP key file: {}", e))?
64+
.try_into()
65+
.expect("Key wrong size: expected 32 bytes");
66+
67+
let server_config = ohttp::KeyConfig::derive(KEY_ID, KEM, SYMMETRIC.to_vec(), &ikm)
68+
.expect("Failed to derive OHTTP keys from file");
69+
70+
info!("Loaded existing OHTTP Key Configuration from {}", key_path.display());
71+
Ok(ServerKeyConfig { ikm, server: ohttp::Server::new(server_config)? })
72+
}
73+
74+
/// Get the path to the key configuration file
75+
/// For now, default to [KEY_ID].ikm.
76+
/// In the future this might be able to save multiple keys named by KeyId.
77+
fn key_path(dir: &Path) -> PathBuf { dir.join(format!("{}.ikm", KEY_ID)) }
78+
79+
#[cfg(test)]
80+
mod tests {
81+
use super::*;
82+
83+
#[test]
84+
fn round_trip_server_config() {
85+
let temp_dir = tempfile::tempdir().expect("Failed to create temp dir");
86+
let ohttp_config = gen_ohttp_server_config().expect("Failed to generate server config");
87+
let _path = persist_new_key_config(ohttp_config.clone(), temp_dir.path())
88+
.expect("Failed to persist server config");
89+
let ohttp_config_again =
90+
read_server_config(temp_dir.path()).expect("Failed to read server config");
91+
assert_eq!(ohttp_config.ikm, ohttp_config_again.ikm);
92+
}
93+
}

payjoin-directory/src/lib.rs

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@ use hyper_util::rt::TokioIo;
1515
use payjoin::directory::{ShortId, ShortIdError, ENCAPSULATED_MESSAGE_BYTES};
1616
use tokio::net::TcpListener;
1717
use tokio::sync::Mutex;
18-
use tracing::{debug, error, info, trace};
18+
use tracing::{debug, error, trace};
1919

2020
use crate::db::DbPool;
21+
pub mod key_config;
22+
pub use crate::key_config::*;
2123

2224
pub const DEFAULT_DIR_PORT: u16 = 8080;
2325
pub const DEFAULT_DB_HOST: &str = "localhost:6379";
@@ -43,11 +45,13 @@ pub async fn listen_tcp_with_tls_on_free_port(
4345
db_host: String,
4446
timeout: Duration,
4547
cert_key: (Vec<u8>, Vec<u8>),
48+
ohttp: ohttp::Server,
4649
) -> Result<(u16, tokio::task::JoinHandle<Result<(), BoxError>>), BoxError> {
4750
let listener = tokio::net::TcpListener::bind("[::]:0").await?;
4851
let port = listener.local_addr()?.port();
4952
println!("Directory server binding to port {}", listener.local_addr()?);
50-
let handle = listen_tcp_with_tls_on_listener(listener, db_host, timeout, cert_key).await?;
53+
let handle =
54+
listen_tcp_with_tls_on_listener(listener, db_host, timeout, cert_key, ohttp).await?;
5155
Ok((port, handle))
5256
}
5357

@@ -58,9 +62,10 @@ async fn listen_tcp_with_tls_on_listener(
5862
db_host: String,
5963
timeout: Duration,
6064
tls_config: (Vec<u8>, Vec<u8>),
65+
ohttp: ohttp::Server,
6166
) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError> {
6267
let pool = DbPool::new(timeout, db_host).await?;
63-
let ohttp = Arc::new(Mutex::new(init_ohttp()?));
68+
let ohttp = Arc::new(Mutex::new(ohttp));
6469
let tls_acceptor = init_tls_acceptor(tls_config)?;
6570
// Spawn the connection handling loop in a separate task
6671
let handle = tokio::spawn(async move {
@@ -100,9 +105,10 @@ pub async fn listen_tcp(
100105
port: u16,
101106
db_host: String,
102107
timeout: Duration,
108+
ohttp: ohttp::Server,
103109
) -> Result<(), Box<dyn std::error::Error>> {
104110
let pool = DbPool::new(timeout, db_host).await?;
105-
let ohttp = Arc::new(Mutex::new(init_ohttp()?));
111+
let ohttp = Arc::new(Mutex::new(ohttp));
106112
let bind_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port);
107113
let listener = TcpListener::bind(bind_addr).await?;
108114
while let Ok((stream, _)) = listener.accept().await {
@@ -134,10 +140,11 @@ pub async fn listen_tcp_with_tls(
134140
db_host: String,
135141
timeout: Duration,
136142
cert_key: (Vec<u8>, Vec<u8>),
143+
ohttp: ohttp::Server,
137144
) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError> {
138145
let addr = format!("0.0.0.0:{}", port);
139146
let listener = tokio::net::TcpListener::bind(&addr).await?;
140-
listen_tcp_with_tls_on_listener(listener, db_host, timeout, cert_key).await
147+
listen_tcp_with_tls_on_listener(listener, db_host, timeout, cert_key, ohttp).await
141148
}
142149

143150
#[cfg(feature = "_danger-local-https")]
@@ -158,21 +165,6 @@ fn init_tls_acceptor(cert_key: (Vec<u8>, Vec<u8>)) -> Result<tokio_rustls::TlsAc
158165
Ok(TlsAcceptor::from(Arc::new(server_config)))
159166
}
160167

161-
fn init_ohttp() -> Result<ohttp::Server> {
162-
use ohttp::hpke::{Aead, Kdf, Kem};
163-
use ohttp::{KeyId, SymmetricSuite};
164-
165-
const KEY_ID: KeyId = 1;
166-
const KEM: Kem = Kem::K256Sha256;
167-
const SYMMETRIC: &[SymmetricSuite] =
168-
&[SymmetricSuite::new(Kdf::HkdfSha256, Aead::ChaCha20Poly1305)];
169-
170-
// create or read from file
171-
let server_config = ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC))?;
172-
info!("Initialized a new OHTTP Key Configuration. GET /ohttp-keys to fetch it.");
173-
Ok(ohttp::Server::new(server_config)?)
174-
}
175-
176168
async fn serve_payjoin_directory(
177169
req: Request<Incoming>,
178170
pool: DbPool,

payjoin-directory/src/main.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ use payjoin_directory::*;
44
use tracing_subscriber::filter::LevelFilter;
55
use tracing_subscriber::EnvFilter;
66

7+
const DEFAULT_KEY_CONFIG_DIR: &str = "ohttp_keys";
8+
79
#[tokio::main]
810
async fn main() -> Result<(), Box<dyn std::error::Error>> {
911
init_logging();
@@ -17,7 +19,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
1719

1820
let db_host = env::var("PJ_DB_HOST").unwrap_or_else(|_| DEFAULT_DB_HOST.to_string());
1921

20-
payjoin_directory::listen_tcp(dir_port, db_host, timeout).await
22+
let key_dir =
23+
std::env::var("PJ_OHTTP_KEY_DIR").map(std::path::PathBuf::from).unwrap_or_else(|_| {
24+
let key_dir = std::path::PathBuf::from(DEFAULT_KEY_CONFIG_DIR);
25+
std::fs::create_dir_all(&key_dir).expect("Failed to create key directory");
26+
key_dir
27+
});
28+
29+
let ohttp = match key_config::read_server_config(&key_dir) {
30+
Ok(config) => config,
31+
Err(_) => {
32+
let ohttp_config = key_config::gen_ohttp_server_config()?;
33+
let path = key_config::persist_new_key_config(ohttp_config, &key_dir)?;
34+
println!("Generated new key configuration at {}", path.display());
35+
key_config::read_server_config(&key_dir).expect("Failed to read newly generated config")
36+
}
37+
};
38+
39+
payjoin_directory::listen_tcp(dir_port, db_host, timeout, ohttp.into()).await
2140
}
2241

2342
fn init_logging() {

payjoin-test-utils/src/lib.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,14 @@ pub async fn init_directory(
131131
> {
132132
println!("Database running on {}", db_host);
133133
let timeout = Duration::from_secs(2);
134-
payjoin_directory::listen_tcp_with_tls_on_free_port(db_host, timeout, local_cert_key).await
134+
let ohttp_server = payjoin_directory::gen_ohttp_server_config()?;
135+
payjoin_directory::listen_tcp_with_tls_on_free_port(
136+
db_host,
137+
timeout,
138+
local_cert_key,
139+
ohttp_server.into(),
140+
)
141+
.await
135142
}
136143

137144
/// generate or get a DER encoded localhost cert and key.

0 commit comments

Comments
 (0)