Skip to content

Commit 0a68efa

Browse files
authored
Merge pull request #194 from MarnixKuijs/alpn
Alpn Support
2 parents 3377126 + 323c866 commit 0a68efa

File tree

7 files changed

+142
-17
lines changed

7 files changed

+142
-17
lines changed

.circleci/config.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
- /usr/local/cargo/registry/index
2222
- restore_cache:
2323
key: deps-<< parameters.image >>-{{ checksum "Cargo.lock" }}
24-
- run: cargo test
24+
- run: cargo test --features alpn
2525
- run: rustdoc --test README.md -L target/debug/deps -L target/debug
2626
- save_cache:
2727
key: deps-<< parameters.image >>-{{ checksum "Cargo.lock" }}
@@ -34,7 +34,7 @@ jobs:
3434
version:
3535
type: string
3636
macos:
37-
xcode: "9.4.1"
37+
xcode: "10.0.0"
3838
environment:
3939
RUST_BACKTRACE: 1
4040
RUSTFLAGS: -D warnings
@@ -51,7 +51,7 @@ jobs:
5151
- ~/.cargo/registry/index
5252
- restore_cache:
5353
key: macos-deps-<< parameters.version >>
54-
- run: cargo test
54+
- run: cargo test --features alpn
5555
- save_cache:
5656
key: macos-deps-<< parameters.version >>
5757
paths:

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ readme = "README.md"
99

1010
[features]
1111
vendored = ["openssl/vendored"]
12+
alpn = ["security-framework/alpn", "openssl/v102"]
1213

1314
[target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies]
1415
security-framework = "2.0.0"

src/imp/openssl.rs

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ use self::openssl::ssl::{
1010
self, MidHandshakeSslStream, SslAcceptor, SslConnector, SslContextBuilder, SslMethod,
1111
SslVerifyMode,
1212
};
13-
use self::openssl::x509::{X509, store::X509StoreBuilder, X509VerifyResult};
13+
use self::openssl::x509::{store::X509StoreBuilder, X509VerifyResult, X509};
1414
use std::error;
1515
use std::fmt;
1616
use std::io;
1717
use std::sync::Once;
1818

19-
use {Protocol, TlsAcceptorBuilder, TlsConnectorBuilder};
2019
use self::openssl::pkey::Private;
20+
use {Protocol, TlsAcceptorBuilder, TlsConnectorBuilder};
2121

2222
#[cfg(have_min_max_version)]
2323
fn supported_protocols(
@@ -274,6 +274,26 @@ impl TlsConnector {
274274
}
275275
}
276276

277+
#[cfg(feature = "alpn")]
278+
{
279+
if !builder.alpn.is_empty() {
280+
// Wire format is each alpn preceded by its length as a byte.
281+
let mut alpn_wire_format = Vec::with_capacity(
282+
builder
283+
.alpn
284+
.iter()
285+
.map(|s| s.as_bytes().len())
286+
.sum::<usize>()
287+
+ builder.alpn.len(),
288+
);
289+
for alpn in builder.alpn.iter().map(|s| s.as_bytes()) {
290+
alpn_wire_format.push(alpn.len() as u8);
291+
alpn_wire_format.extend(alpn);
292+
}
293+
connector.set_alpn_protos(&alpn_wire_format)?;
294+
}
295+
}
296+
277297
#[cfg(target_os = "android")]
278298
load_android_root_certs(&mut connector)?;
279299

@@ -305,8 +325,7 @@ impl TlsConnector {
305325

306326
impl fmt::Debug for TlsConnector {
307327
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
308-
fmt
309-
.debug_struct("TlsConnector")
328+
fmt.debug_struct("TlsConnector")
310329
// n.b. SslConnector is a newtype on SslContext which implements a noop Debug so it's omitted
311330
.field("use_sni", &self.use_sni)
312331
.field("accept_invalid_hostnames", &self.accept_invalid_hostnames)
@@ -367,6 +386,15 @@ impl<S: io::Read + io::Write> TlsStream<S> {
367386
Ok(self.0.ssl().peer_certificate().map(Certificate))
368387
}
369388

389+
#[cfg(feature = "alpn")]
390+
pub fn negotiated_alpn(&self) -> Result<Option<Vec<u8>>, Error> {
391+
Ok(self
392+
.0
393+
.ssl()
394+
.selected_alpn_protocol()
395+
.map(|alpn| alpn.to_vec()))
396+
}
397+
370398
pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error> {
371399
let cert = if self.0.ssl().is_server() {
372400
self.0.ssl().certificate().map(|x| x.to_owned())

src/imp/schannel.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ impl Identity {
8686
return Err(io::Error::new(
8787
io::ErrorKind::InvalidInput,
8888
"No identity found in PKCS #12 archive",
89-
).into());
89+
)
90+
.into());
9091
}
9192
};
9293

@@ -112,7 +113,8 @@ impl Certificate {
112113
Err(_) => Err(io::Error::new(
113114
io::ErrorKind::InvalidInput,
114115
"PEM representation contains non-UTF-8 bytes",
115-
).into()),
116+
)
117+
.into()),
116118
}
117119
}
118120

@@ -186,6 +188,7 @@ pub struct TlsConnector {
186188
accept_invalid_hostnames: bool,
187189
accept_invalid_certs: bool,
188190
disable_built_in_roots: bool,
191+
alpn: Vec<String>,
189192
}
190193

191194
impl TlsConnector {
@@ -205,6 +208,7 @@ impl TlsConnector {
205208
accept_invalid_hostnames: builder.accept_invalid_hostnames,
206209
accept_invalid_certs: builder.accept_invalid_certs,
207210
disable_built_in_roots: builder.disable_built_in_roots,
211+
alpn: builder.alpn.clone(),
208212
})
209213
}
210214

@@ -249,6 +253,14 @@ impl TlsConnector {
249253
))
250254
});
251255
}
256+
#[cfg(feature = "alpn")]
257+
{
258+
if !self.alpn.is_empty() {
259+
builder.request_application_protocols(
260+
&self.alpn.iter().map(|s| s.as_bytes()).collect::<Vec<_>>(),
261+
);
262+
}
263+
}
252264
match builder.connect(cred, stream) {
253265
Ok(s) => Ok(TlsStream(s)),
254266
Err(e) => Err(e.into()),
@@ -319,6 +331,11 @@ impl<S: io::Read + io::Write> TlsStream<S> {
319331
}
320332
}
321333

334+
#[cfg(feature = "alpn")]
335+
pub fn negotiated_alpn(&self) -> Result<Option<Vec<u8>>, Error> {
336+
Ok(self.0.negotiated_application_protocol()?)
337+
}
338+
322339
pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error> {
323340
let cert = if self.0.is_server() {
324341
self.0.certificate()

src/imp/security_framework.rs

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@ use self::security_framework::import_export::{ImportedIdentity, Pkcs12ImportOpti
1010
use self::security_framework::secure_transport::{
1111
self, ClientBuilder, SslConnectionType, SslContext, SslProtocol, SslProtocolSide,
1212
};
13-
use self::security_framework_sys::base::errSecIO;
13+
use self::security_framework_sys::base::{errSecIO, errSecParam};
1414
use self::tempfile::TempDir;
1515
use std::error;
1616
use std::fmt;
1717
use std::io;
18+
use std::str;
1819
use std::sync::Mutex;
1920
use std::sync::Once;
2021

@@ -23,11 +24,11 @@ use self::security_framework::os::macos::certificate::{PropertyType, SecCertific
2324
#[cfg(not(target_os = "ios"))]
2425
use self::security_framework::os::macos::certificate_oids::CertificateOid;
2526
#[cfg(not(target_os = "ios"))]
26-
use self::security_framework::os::macos::import_export::{ImportOptions, SecItems, Pkcs12ImportOptionsExt};
27+
use self::security_framework::os::macos::import_export::{
28+
ImportOptions, Pkcs12ImportOptionsExt, SecItems,
29+
};
2730
#[cfg(not(target_os = "ios"))]
2831
use self::security_framework::os::macos::keychain::{self, KeychainSettings, SecKeychain};
29-
#[cfg(not(target_os = "ios"))]
30-
use self::security_framework_sys::base::errSecParam;
3132

3233
use {Protocol, TlsAcceptorBuilder, TlsConnectorBuilder};
3334

@@ -128,10 +129,10 @@ impl Identity {
128129
keychain
129130
}
130131
};
131-
let imports = Pkcs12ImportOptions::new()
132-
.passphrase(pass)
133-
.keychain(keychain)
134-
.import(buf)?;
132+
let mut import_opts = Pkcs12ImportOptions::new();
133+
// Method shadowed by deprecated method.
134+
<Pkcs12ImportOptions as Pkcs12ImportOptionsExt>::keychain(&mut import_opts, keychain);
135+
let imports = import_opts.passphrase(pass).import(buf)?;
135136
Ok(imports)
136137
}
137138

@@ -263,6 +264,7 @@ pub struct TlsConnector {
263264
danger_accept_invalid_hostnames: bool,
264265
danger_accept_invalid_certs: bool,
265266
disable_built_in_roots: bool,
267+
alpn: Vec<String>,
266268
}
267269

268270
impl TlsConnector {
@@ -280,6 +282,7 @@ impl TlsConnector {
280282
danger_accept_invalid_hostnames: builder.accept_invalid_hostnames,
281283
danger_accept_invalid_certs: builder.accept_invalid_certs,
282284
disable_built_in_roots: builder.disable_built_in_roots,
285+
alpn: builder.alpn.clone(),
283286
})
284287
}
285288

@@ -303,6 +306,13 @@ impl TlsConnector {
303306
builder.danger_accept_invalid_certs(self.danger_accept_invalid_certs);
304307
builder.trust_anchor_certificates_only(self.disable_built_in_roots);
305308

309+
#[cfg(feature = "alpn")]
310+
{
311+
if !self.alpn.is_empty() {
312+
builder.alpn_protocols(&self.alpn.iter().map(String::as_str).collect::<Vec<_>>());
313+
}
314+
}
315+
306316
match builder.handshake(domain, stream) {
307317
Ok(stream) => Ok(TlsStream { stream, cert: None }),
308318
Err(e) => Err(e.into()),
@@ -388,6 +398,27 @@ impl<S: io::Read + io::Write> TlsStream<S> {
388398
Ok(trust.certificate_at_index(0).map(Certificate))
389399
}
390400

401+
#[cfg(feature = "alpn")]
402+
pub fn negotiated_alpn(&self) -> Result<Option<Vec<u8>>, Error> {
403+
match self.stream.context().alpn_protocols() {
404+
Ok(protocols) => {
405+
// Per RFC7301, "ProtocolNameList" MUST contain exactly one "ProtocolName".
406+
assert!(protocols.len() < 2);
407+
408+
if protocols.is_empty() {
409+
// Not sure this is actually possible.
410+
Ok(None)
411+
} else {
412+
Ok(Some(protocols.into_iter().next().unwrap().into_bytes()))
413+
}
414+
}
415+
// The macOS API appears to return `errSecParam` whenever no ALPN was negotiated, both
416+
// when it isn't attempted and when it isn't successful.
417+
Err(e) if e.code() == errSecParam => Ok(None),
418+
Err(other) => Err(Error::from(other)),
419+
}
420+
}
421+
391422
#[cfg(target_os = "ios")]
392423
pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error> {
393424
Ok(None)

src/lib.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ pub struct TlsConnectorBuilder {
328328
accept_invalid_hostnames: bool,
329329
use_sni: bool,
330330
disable_built_in_roots: bool,
331+
alpn: Vec<String>,
331332
}
332333

333334
impl TlsConnectorBuilder {
@@ -376,6 +377,15 @@ impl TlsConnectorBuilder {
376377
self
377378
}
378379

380+
/// Request specific protocols through ALPN (Application-Layer Protocol Negotiation).
381+
///
382+
/// Defaults to none
383+
#[cfg(feature = "alpn")]
384+
pub fn request_alpns(&mut self, protocols: &[&str]) -> &mut TlsConnectorBuilder {
385+
self.alpn = protocols.iter().map(|s| (*s).to_owned()).collect();
386+
self
387+
}
388+
379389
/// Controls the use of certificate validation.
380390
///
381391
/// Defaults to `false`.
@@ -464,6 +474,7 @@ impl TlsConnector {
464474
accept_invalid_certs: false,
465475
accept_invalid_hostnames: false,
466476
disable_built_in_roots: false,
477+
alpn: vec![],
467478
}
468479
}
469480

@@ -644,6 +655,12 @@ impl<S: io::Read + io::Write> TlsStream<S> {
644655
Ok(self.0.tls_server_end_point()?)
645656
}
646657

658+
/// Returns the negotiated ALPN protocols
659+
#[cfg(feature = "alpn")]
660+
pub fn negotiated_alpn(&self) -> Result<Option<Vec<u8>>> {
661+
Ok(self.0.negotiated_alpn()?)
662+
}
663+
647664
/// Shuts down the TLS session.
648665
pub fn shutdown(&mut self) -> io::Result<()> {
649666
self.0.shutdown()?;

src/test.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use hex;
22
#[allow(unused_imports)]
33
use std::io::{Read, Write};
44
use std::net::{TcpListener, TcpStream};
5+
use std::string::String;
56
use std::thread;
67

78
use super::*;
@@ -417,4 +418,34 @@ mod tests {
417418

418419
p!(j.join());
419420
}
421+
422+
#[test]
423+
#[cfg(feature = "alpn")]
424+
fn alpn_google_h2() {
425+
let builder = p!(TlsConnector::builder().request_alpns(&["h2"]).build());
426+
let s = p!(TcpStream::connect("google.com:443"));
427+
let socket = p!(builder.connect("google.com", s));
428+
let alpn = p!(socket.negotiated_alpn());
429+
assert_eq!(alpn, Some(b"h2".to_vec()));
430+
}
431+
432+
#[test]
433+
#[cfg(feature = "alpn")]
434+
fn alpn_google_invalid() {
435+
let builder = p!(TlsConnector::builder().request_alpns(&["h2c"]).build());
436+
let s = p!(TcpStream::connect("google.com:443"));
437+
let socket = p!(builder.connect("google.com", s));
438+
let alpn = p!(socket.negotiated_alpn());
439+
assert_eq!(alpn, None);
440+
}
441+
442+
#[test]
443+
#[cfg(feature = "alpn")]
444+
fn alpn_google_none() {
445+
let builder = p!(TlsConnector::new());
446+
let s = p!(TcpStream::connect("google.com:443"));
447+
let socket = p!(builder.connect("google.com", s));
448+
let alpn = p!(socket.negotiated_alpn());
449+
assert_eq!(alpn, None);
450+
}
420451
}

0 commit comments

Comments
 (0)