Skip to content

Commit d231c7f

Browse files
committed
Direct send for none ciphers, prevent unnecessary data copy
1 parent e64eb3f commit d231c7f

File tree

7 files changed

+183
-78
lines changed

7 files changed

+183
-78
lines changed

src/crypto/cipher.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,8 @@ pub enum CipherType {
263263
/// Category of ciphers
264264
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
265265
pub enum CipherCategory {
266+
/// No encryption
267+
None,
266268
/// Stream ciphers is used for OLD ShadowSocks protocol, which uses stream ciphers to encrypt data payloads
267269
Stream,
268270
/// AEAD ciphers is used in modern ShadowSocks protocol, which sends data in separate packets
@@ -542,6 +544,8 @@ impl CipherType {
542544
/// Get category of cipher
543545
pub fn category(self) -> CipherCategory {
544546
match self {
547+
CipherType::None | CipherType::Plain => CipherCategory::None,
548+
545549
CipherType::Aes128Gcm | CipherType::Aes256Gcm | CipherType::ChaCha20IetfPoly1305 => CipherCategory::Aead,
546550

547551
#[cfg(feature = "sodium")]

src/relay/dnsrelay/mod.rs

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ fn should_forward_by_ptr_name(acl: &AccessControl, name: &Name) -> bool {
3131
let mut iter = name.iter().rev();
3232
let mut next = || match iter.next() {
3333
Some(label) => std::str::from_utf8(label).unwrap_or("*"),
34-
None => "0", // zero fill the missing labels
34+
None => "0", // zero fill the missing labels
3535
};
3636
if !"arpa".eq_ignore_ascii_case(next()) {
3737
return acl.is_default_in_proxy_list();
@@ -56,7 +56,14 @@ fn should_forward_by_ptr_name(acl: &AccessControl, name: &Name) -> bool {
5656
}
5757
}
5858
acl.check_ip_in_proxy_list(&IpAddr::V6(Ipv6Addr::new(
59-
segments[0], segments[1], segments[2], segments[3], segments[4], segments[5], segments[6], segments[7]
59+
segments[0],
60+
segments[1],
61+
segments[2],
62+
segments[3],
63+
segments[4],
64+
segments[5],
65+
segments[6],
66+
segments[7],
6067
)))
6168
}
6269
_ => acl.is_default_in_proxy_list(),
@@ -118,7 +125,7 @@ fn should_forward_by_response(
118125
} else {
119126
acl.is_default_in_proxy_list()
120127
}
121-
}}
128+
}};
122129
}
123130
macro_rules! examine_record {
124131
($rec:ident, $is_answer:expr) => {
@@ -132,8 +139,11 @@ fn should_forward_by_response(
132139
continue;
133140
}
134141
if $is_answer && !query.query_type().is_any() && $rec.record_type() != query.query_type() {
135-
warn!("local DNS response has inconsistent answer type {} for query {}",
136-
$rec.record_type(), query);
142+
warn!(
143+
"local DNS response has inconsistent answer type {} for query {}",
144+
$rec.record_type(),
145+
query
146+
);
137147
return true;
138148
}
139149
let forward = match $rec.rdata() {
@@ -153,7 +163,11 @@ fn should_forward_by_response(
153163
}
154164
for rec in local_response.answers() {
155165
if !names.contains(rec.name()) {
156-
warn!("local DNS response contains unexpected name {} for query {}", rec.name(), query);
166+
warn!(
167+
"local DNS response contains unexpected name {} for query {}",
168+
rec.name(),
169+
query
170+
);
157171
return true;
158172
}
159173
examine_record!(rec, true);
@@ -183,7 +197,10 @@ impl<Remote: upstream::Upstream> DnsRelay<Remote> {
183197
// Start querying name servers
184198
debug!(
185199
"attempting lookup of {:?} {} with ns {:?} and {:?}",
186-
query.query_type(), query.name(), local, remote
200+
query.query_type(),
201+
query.name(),
202+
local,
203+
remote
187204
);
188205

189206
let remote_response_fut = try_timeout(remote.lookup(query), Some(Duration::new(3, 0)));
@@ -274,7 +291,7 @@ impl<Remote: upstream::Upstream> DnsRelay<Remote> {
274291

275292
async fn run_tcp<Remote: upstream::Upstream + Send + Sync + 'static>(
276293
relay: Arc<DnsRelay<Remote>>,
277-
bind_addr: SocketAddr
294+
bind_addr: SocketAddr,
278295
) -> io::Result<()> {
279296
let mut listener = TcpListener::bind(&bind_addr).await?;
280297

@@ -302,7 +319,7 @@ async fn run_tcp<Remote: upstream::Upstream + Send + Sync + 'static>(
302319

303320
async fn run_udp<Remote: upstream::Upstream + Send + Sync + 'static>(
304321
relay: Arc<DnsRelay<Remote>>,
305-
bind_addr: SocketAddr
322+
bind_addr: SocketAddr,
306323
) -> io::Result<()> {
307324
let socket = create_udp_socket(&bind_addr).await?;
308325

@@ -379,12 +396,13 @@ pub async fn run(context: SharedContext) -> io::Result<()> {
379396
context: context.clone(),
380397
svr_cfg: move || balancer.pick_server().server_config().clone(),
381398
ns: config.remote_dns_addr.clone().expect("remote query DNS address"),
382-
}
399+
},
383400
});
384401

385402
future::select(
386403
tokio::spawn(run_tcp(relay.clone(), bind_addr)),
387404
tokio::spawn(run_udp(relay, bind_addr)),
388-
).await;
405+
)
406+
.await;
389407
Ok(())
390408
}

src/relay/dnsrelay/upstream.rs

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ use byteorder::{BigEndian, ByteOrder};
1010
use rand::Rng;
1111
use tokio::{
1212
io::{AsyncReadExt, AsyncWriteExt},
13-
net::UdpSocket
13+
net::UdpSocket,
1414
};
1515
use trust_dns_proto::{
1616
op::{Message, Query},
17-
rr::{DNSClass, RecordType, Name, RData},
17+
rr::{DNSClass, Name, RData, RecordType},
1818
};
1919

2020
#[cfg(unix)]
@@ -25,10 +25,7 @@ use tokio::net::UnixStream;
2525
use crate::{
2626
config::{Config, ServerConfig},
2727
context::SharedContext,
28-
relay::{
29-
socks5::Address,
30-
tcprelay::ProxyStream,
31-
},
28+
relay::{socks5::Address, tcprelay::ProxyStream},
3229
};
3330

3431
#[derive(Debug)]
@@ -118,8 +115,8 @@ pub async fn write_message<T: AsyncWriteExt + Unpin>(stream: &mut T, message: &M
118115
}
119116

120117
async fn stream_lookup<T>(query: &Query, stream: &mut T) -> io::Result<Message>
121-
where
122-
T: AsyncReadExt + AsyncWriteExt + Unpin,
118+
where
119+
T: AsyncReadExt + AsyncWriteExt + Unpin,
123120
{
124121
write_message(stream, &generate_query_message(query)).await?;
125122
read_message(stream).await
@@ -133,10 +130,14 @@ pub struct UdpUpstream {
133130
#[async_trait]
134131
impl Upstream for UdpUpstream {
135132
async fn lookup(&self, query: &Query) -> io::Result<Message> {
136-
let mut socket = UdpSocket::bind(SocketAddr::new(match self.server {
137-
SocketAddr::V4(..) => IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
138-
SocketAddr::V6(..) => IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)),
139-
}, 0)).await?;
133+
let mut socket = UdpSocket::bind(SocketAddr::new(
134+
match self.server {
135+
SocketAddr::V4(..) => IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
136+
SocketAddr::V6(..) => IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)),
137+
},
138+
0,
139+
))
140+
.await?;
140141
socket.connect(self.server).await?;
141142
socket.send(&generate_query_message(query).to_vec()?).await?;
142143
let mut response = vec![0; 512];
@@ -165,14 +166,15 @@ pub struct ProxyTcpUpstream<F> {
165166

166167
impl<F> Debug for ProxyTcpUpstream<F> {
167168
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
168-
f.debug_struct("ProxyTcpUpstream")
169-
.field("ns", &self.ns)
170-
.finish()
169+
f.debug_struct("ProxyTcpUpstream").field("ns", &self.ns).finish()
171170
}
172171
}
173172

174173
#[async_trait]
175-
impl<F> Upstream for ProxyTcpUpstream<F> where F: Fn() -> ServerConfig + Send + Sync {
174+
impl<F> Upstream for ProxyTcpUpstream<F>
175+
where
176+
F: Fn() -> ServerConfig + Send + Sync,
177+
{
176178
async fn lookup(&self, query: &Query) -> io::Result<Message> {
177179
let mut stream = ProxyStream::connect_proxied(self.context.clone(), &(self.svr_cfg)(), &self.ns).await?;
178180
stream_lookup(query, &mut stream).await

src/relay/tcprelay/crypto_io.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,13 @@ use super::{
2828
};
2929

3030
enum DecryptedReader {
31+
None,
3132
Aead(AeadDecryptedReader),
3233
Stream(StreamDecryptedReader),
3334
}
3435

3536
enum EncryptedWriter {
37+
None,
3638
Aead(AeadEncryptedWriter),
3739
Stream(StreamEncryptedWriter),
3840
}
@@ -62,9 +64,14 @@ impl<S> CryptoStream<S> {
6264
/// Create a new CryptoStream with the underlying stream connection
6365
pub fn new(context: SharedContext, stream: S, svr_cfg: &ServerConfig) -> CryptoStream<S> {
6466
let method = svr_cfg.method();
67+
if method.category() == CipherCategory::None {
68+
return CryptoStream::<S>::new_none(stream);
69+
}
70+
6571
let prev_len = match method.category() {
6672
CipherCategory::Stream => method.iv_size(),
6773
CipherCategory::Aead => method.salt_size(),
74+
CipherCategory::None => 0,
6875
};
6976

7077
let iv = match method.category() {
@@ -92,12 +99,14 @@ impl<S> CryptoStream<S> {
9299
trace!("generated AEAD cipher salt {:?}", local_salt);
93100
local_salt
94101
}
102+
CipherCategory::None => Bytes::new(),
95103
};
96104

97105
let method = svr_cfg.method();
98106
let enc = match method.category() {
99107
CipherCategory::Stream => EncryptedWriter::Stream(StreamEncryptedWriter::new(method, svr_cfg.key(), iv)),
100108
CipherCategory::Aead => EncryptedWriter::Aead(AeadEncryptedWriter::new(method, svr_cfg.key(), iv)),
109+
CipherCategory::None => EncryptedWriter::None,
101110
};
102111

103112
CryptoStream {
@@ -108,6 +117,15 @@ impl<S> CryptoStream<S> {
108117
}
109118
}
110119

120+
fn new_none(stream: S) -> CryptoStream<S> {
121+
CryptoStream {
122+
stream,
123+
dec: Some(DecryptedReader::None),
124+
enc: EncryptedWriter::None,
125+
read_status: ReadStatus::Established,
126+
}
127+
}
128+
111129
/// Return a reference to the underlying stream
112130
pub fn get_ref(&self) -> &S {
113131
&self.stream
@@ -148,6 +166,7 @@ where
148166
trace!("got AEAD cipher salt {:?}", ByteStr::new(&buf));
149167
DecryptedReader::Aead(AeadDecryptedReader::new(method, key, &buf))
150168
}
169+
CipherCategory::None => DecryptedReader::None,
151170
};
152171

153172
self.dec = Some(dec);
@@ -162,6 +181,7 @@ where
162181
ready!(this.poll_read_handshake(ctx))?;
163182

164183
match *this.dec.as_mut().unwrap() {
184+
DecryptedReader::None => Pin::new(&mut this.stream).poll_read(ctx, buf),
165185
DecryptedReader::Aead(ref mut r) => r.poll_read_decrypted(ctx, &mut this.stream, buf),
166186
DecryptedReader::Stream(ref mut r) => r.poll_read_decrypted(ctx, &mut this.stream, buf),
167187
}
@@ -175,6 +195,7 @@ where
175195
fn priv_poll_write(self: Pin<&mut Self>, ctx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
176196
let this = self.get_mut();
177197
match this.enc {
198+
EncryptedWriter::None => Pin::new(&mut this.stream).poll_write(ctx, buf),
178199
EncryptedWriter::Aead(ref mut w) => w.poll_write_encrypted(ctx, &mut this.stream, buf),
179200
EncryptedWriter::Stream(ref mut w) => w.poll_write_encrypted(ctx, &mut this.stream, buf),
180201
}

0 commit comments

Comments
 (0)