Skip to content

Commit 5058c76

Browse files
committed
refactor(lsp): improve performance by optimizing byte handling and method checks.
- Replace string conversions with direct byte operations for efficiency - Simplify method matching using memchr::find for lower overhead - Refactor request tracking to remove unnecessary expected argument - Use pointer_mut for JSON mutation in pid handler - Avoid redundant error handling and conversions - Reduce allocations in send_message by writing bytes directly
1 parent 8c16143 commit 5058c76

File tree

3 files changed

+88
-89
lines changed

3 files changed

+88
-89
lines changed

src/lsp/binding.rs

Lines changed: 63 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::{config::ProxyConfig, proxy::Pair};
2-
use memchr::memmem::find_iter;
2+
use memchr::memmem::{find, find_iter};
33
use serde_json::{Value, json};
44
use std::{path::PathBuf, process::Stdio, str, sync::Arc};
55
use tokio::{
@@ -34,22 +34,21 @@ pub fn redirect_uri(
3434
}
3535
}
3636

37-
trace!(?from_path, ?to_path);
37+
trace!(from=?String::from_utf8(from_path.to_vec()), to=?String::from_utf8(to_path.to_vec()));
3838

3939
let occurrences = find_iter(&raw_bytes, from_path);
4040
let from_n = from_path.len();
4141
let mut new_bytes: Bytes = Bytes::new();
42+
let mut last = 0;
4243

4344
for occurr in occurrences {
44-
let before = if occurr > 0 {
45-
&raw_bytes[..occurr]
46-
} else {
47-
&Bytes::new()
48-
};
49-
let after = &raw_bytes[occurr + from_n..];
45+
let before = &raw_bytes[last..occurr];
46+
last = occurr + from_n;
5047
// add the new text and join
51-
new_bytes = Bytes::from([before, to_path, after].concat());
48+
new_bytes = Bytes::from([&new_bytes, before, to_path].concat());
5249
}
50+
let after = &raw_bytes[last..];
51+
new_bytes = Bytes::from([&new_bytes, after].concat());
5352

5453
*raw_bytes = new_bytes;
5554

@@ -115,13 +114,13 @@ impl RequestTracker {
115114
self.map.write().await.insert(id, method.to_string());
116115
}
117116

118-
async fn take_if_match(&self, id: u64, expected: &str) -> bool {
117+
async fn take_if_match(&self, id: u64) -> bool {
119118
let mut map = self.map.write().await;
120-
let exists = map.get(&id).map(|m| m == expected).unwrap_or(false);
121-
if exists {
119+
if map.get(&id).is_some() {
122120
map.remove(&id);
121+
return true;
123122
}
124-
exists
123+
false
125124
}
126125

127126
pub async fn check_for_methods(
@@ -135,79 +134,81 @@ impl RequestTracker {
135134
return Ok(());
136135
}
137136

138-
//textDocument/declaration
139-
140137
match pair {
141138
Pair::Server => {
142-
let mut v: Value = serde_json::from_slice(&raw_bytes)?;
139+
// Early return
140+
if self.map.read().await.is_empty() {
141+
trace!("Nothing expecting response, skipping method");
142+
return Ok(());
143+
}
144+
145+
let mut v: Value = serde_json::from_slice(raw_bytes.as_ref())?;
143146
trace!(server_response=%v, "received");
144147

145148
// Check if this is a response to a tracked request
146149
if let Some(id) = v.get("id").and_then(Value::as_u64) {
147-
for method in methods {
148-
debug!("Checking for {method} method");
149-
150-
let matches = self.take_if_match(id, *method).await;
151-
debug!(%matches);
152-
if matches {
153-
trace!(%id, "matches");
154-
if let Some(results) = v.get_mut("result").and_then(Value::as_array_mut)
155-
{
156-
trace!(?results);
157-
for result in results {
158-
if let Some(uri_val) =
159-
result.get("uri").and_then(|u| u.as_str())
160-
{
161-
if !(uri_val.contains(&self.config.local_path)) {
162-
debug!(%uri_val);
163-
let new_uri =
164-
self.bind_library(uri_val.to_string()).await?;
165-
debug!("file://{}", new_uri);
166-
167-
Self::modify_uri(result, &new_uri);
168-
}
150+
let matches = self.take_if_match(id).await;
151+
debug!(%matches);
152+
if matches {
153+
trace!(%id, "matches");
154+
if let Some(results) = v.get_mut("result").and_then(Value::as_array_mut) {
155+
trace!(?results);
156+
for result in results {
157+
if let Some(uri_val) = result.get("uri").and_then(|u| u.as_str()) {
158+
if !(uri_val.contains(&self.config.local_path)) {
159+
debug!(%uri_val);
160+
let new_uri = self.bind_library(uri_val).await?;
161+
debug!("file://{}", new_uri);
162+
163+
Self::modify_uri(result, &new_uri);
169164
}
170165
}
171-
172-
if let Some(vstr) = v.as_str() {
173-
*raw_bytes = Bytes::from(vstr.as_bytes().to_owned());
174-
} else {
175-
error!(%v ,"error converting to str");
176-
}
177-
} else {
178-
trace!("result content not found");
179166
}
167+
168+
*raw_bytes = Bytes::from(serde_json::to_vec(&v)?);
169+
} else {
170+
trace!("result content not found");
180171
}
181172
}
182173
}
183174
}
184175

185176
Pair::Client => {
186-
let v: Value = serde_json::from_slice(&raw_bytes)?;
177+
// Early check to avoid parsing
178+
let mut method_found = "";
179+
for method in methods {
180+
debug!("Checking for {method} method");
181+
let expected = &[b"\"method\":\"", method.as_bytes(), b"\""].concat();
182+
if find(raw_bytes, expected).is_some() {
183+
method_found = method;
184+
break;
185+
}
186+
}
187+
188+
if method_found.is_empty() {
189+
debug!("Any method that required redirection was not found, skipping patch");
190+
return Ok(());
191+
}
192+
193+
debug!(%method_found);
194+
195+
let v: Value = serde_json::from_slice(raw_bytes.as_ref())?;
187196
trace!(client_request=%v, "received");
188197

189198
debug!("Checking for id");
190199
if let Some(id) = v.get("id").and_then(Value::as_u64) {
191200
debug!(%id);
192-
193-
if let Some(req_method) = v.get("method").and_then(Value::as_str) {
194-
trace!(%req_method);
195-
// Only track expected methods if URI matches
196-
for method in methods {
197-
if req_method == *method {
198-
debug!(%id, "Storing");
199-
self.track(id, *method).await;
200-
}
201-
}
202-
}
201+
// Only track expected methods if URI matches
202+
self.track(id, method_found).await;
203+
debug!(%id, "Storing");
203204
}
204205
}
205206
}
206207

207208
Ok(())
208209
}
209210

210-
async fn bind_library(&self, uri: String) -> std::io::Result<String> {
211+
async fn bind_library(&self, uri: &str) -> std::io::Result<String> {
211212
let temp_dir = std::env::temp_dir().join("lspdock");
212213
trace!(temp_dir=%temp_dir.to_string_lossy());
213214

@@ -222,7 +223,7 @@ impl RequestTracker {
222223
} else {
223224
let relative_path = safe_path.strip_prefix("/").unwrap_or(&safe_path);
224225
trace!(%relative_path);
225-
let tmp_file_path = relative_path.to_string();
226+
let tmp_file_path = relative_path;
226227
temp_dir.join(tmp_file_path)
227228
};
228229

@@ -238,7 +239,7 @@ impl RequestTracker {
238239
let temp_uri_path = PathBuf::from(&temp_uri);
239240
debug!(%temp_uri);
240241
if !temp_uri_path.exists() {
241-
self.copy_file(safe_path.to_string(), &temp_uri).await?;
242+
self.copy_file(&safe_path, &temp_uri).await?;
242243
} else {
243244
debug!("File already exists, skipping copy. {}", temp_uri);
244245
}
@@ -247,7 +248,7 @@ impl RequestTracker {
247248
}
248249

249250
/// Copies a file from either the local filesystem or a Docker container.
250-
async fn copy_file(&self, path: String, destination: &str) -> std::io::Result<()> {
251+
async fn copy_file(&self, path: &str, destination: &str) -> std::io::Result<()> {
251252
// Only copy the file if the LSP is in a container
252253
debug!("Starting file copy from {} to {}", path, destination);
253254
let cmd = Command::new("docker")

src/lsp/parser.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,15 @@ pub async fn send_message(
134134
let len = message.len();
135135
debug!(%len, "Sending message");
136136
trace!(?message);
137-
let message_str = String::from_utf8(message.to_vec())?;
138-
let msg = format!("Content-Length: {len}\r\n\r\n{message_str}");
139-
140-
writer.write_all(msg.as_bytes()).await?;
137+
let msg = &[
138+
b"Content-Length: ",
139+
len.to_string().as_bytes(),
140+
b"\r\n\r\n",
141+
&message,
142+
]
143+
.concat();
144+
145+
writer.write_all(msg).await?;
141146
writer.flush().await?;
142147

143148
Ok(())

src/lsp/pid.rs

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use memchr::memmem::find;
22
use serde_json::{Value, json};
33
use sysinfo::{Pid, System};
44
use tokio_util::{bytes::Bytes, sync::CancellationToken};
5-
use tracing::{debug, error, info, trace, warn};
5+
use tracing::{debug, info, trace, warn};
66

77
pub struct PidHandler {
88
pid: Option<u64>,
@@ -27,31 +27,24 @@ impl PidHandler {
2727
&mut self,
2828
raw_bytes: &mut Bytes,
2929
) -> serde_json::error::Result<bool> {
30-
if find(raw_bytes, br#""method":"initialize""#).is_some() {
31-
debug!("Initialize method found, patching");
32-
trace!(?raw_bytes, "before patch");
30+
if find(raw_bytes, br#""method":"initialize""#).is_none() {
31+
trace!("Initialize method not found, skipping patch");
32+
return Ok(false);
33+
}
3334

34-
let mut v: Value = serde_json::from_slice(&raw_bytes)?;
35-
if let Some(process_id) = v
36-
.get_mut("params")
37-
.and_then(|params| params.get_mut("processId"))
38-
{
39-
self.pid = process_id.as_u64();
40-
trace!(self.pid, "captured PID");
41-
*process_id = json!("null");
42-
}
35+
debug!("Initialize method found, patching");
36+
trace!(?raw_bytes, "before patch");
4337

44-
if let Some(vstr) = v.as_str() {
45-
*raw_bytes = Bytes::from(vstr.as_bytes().to_owned());
46-
} else {
47-
error!(%v ,"error converting to str");
48-
}
49-
50-
trace!(?raw_bytes, "patched");
51-
return Ok(true);
38+
let mut v: Value = serde_json::from_slice(raw_bytes.as_ref())?;
39+
if let Some(process_id) = v.pointer_mut("/params/processId") {
40+
self.pid = process_id.as_u64();
41+
trace!(self.pid, "captured PID");
42+
*process_id = json!("null");
5243
}
53-
trace!("Initialize method not found, skipping patch");
54-
return Ok(false);
44+
*raw_bytes = Bytes::from(serde_json::to_vec(&v)?);
45+
46+
trace!(?raw_bytes, "patched");
47+
return Ok(true);
5548
}
5649

5750
/// Monitor periodically if the PID is running

0 commit comments

Comments
 (0)