Skip to content

Commit c781b49

Browse files
authored
Check files in parallel with crossbeam (#6783)
* chore: add `crossbeam-channel` as a dependency for the mpmc channel * refactor: package up all the git repository details in `Repository` * refactor: encapsulate logic for cloning repositories * refactor: add `check_diff_for_file` to support parallel file checks Processing all files in parallel is more efficient than checking each repo in parallel. One issue with the current design of processing each repo in parallel is that threads that process smaller repos end early and don't help process files from larger repos. For example, r-l/rust is a large repo that takes a long time to check because there's only one thread working on it. Besides enabling us to process all files in parallel, `check_diff_for_file` return a `Result<(), (Diff, ..)>` instead of a `u8`, which is a more idiomatic way to represent any errors we find when checking for diffs. * refactor: process files in parallel and report errors at the end Now we'll process each file in parallel instead of processing each repo in parallel, and after checking all repositories we'll report on any errors that we've found. * feat: Add `worker_threads` option This option control how many threads process files in parallel. Setting the default to 16 as that's a common multiple of CPU cores.
1 parent a57ce88 commit c781b49

File tree

5 files changed

+252
-113
lines changed

5 files changed

+252
-113
lines changed

check_diff/Cargo.lock

Lines changed: 17 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

check_diff/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
1212
tempfile = "3"
1313
walkdir = "2.5.0"
1414
diffy = "0.4.0"
15+
crossbeam-channel = "0.5.15"

check_diff/src/lib.rs

Lines changed: 197 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
use std::borrow::Cow;
2+
use std::collections::HashMap;
23
use std::env;
34
use std::fmt::{Debug, Display};
45
use std::io::{self, Write};
56
use std::path::{Path, PathBuf};
67
use std::process::{Command, Stdio};
78
use std::str::FromStr;
8-
use tracing::{debug, error, info, trace};
9+
use std::sync::{Arc, Mutex};
10+
use tempfile::tempdir;
11+
use tracing::{debug, info, trace, warn};
912
use walkdir::WalkDir;
1013

1114
#[derive(Debug, Clone, Copy)]
@@ -411,6 +414,46 @@ fn create_config_arg<T: AsRef<str>>(configs: Option<&[T]>) -> Cow<'static, str>
411414

412415
Cow::Owned(result)
413416
}
417+
418+
pub struct Repository<P> {
419+
/// Name of the repository
420+
name: String,
421+
/// Path to the repository on the local file system
422+
dir_path: P,
423+
}
424+
425+
impl<P> Repository<P> {
426+
/// Initialize a new Repository
427+
pub fn new(git_url: &str, dir_path: P) -> Self {
428+
let name = get_repo_name(git_url).to_string();
429+
Self { name, dir_path }
430+
}
431+
432+
/// Get the `name` of the repository
433+
pub fn name(&self) -> &str {
434+
&self.name
435+
}
436+
437+
/// Get the absolute path to where this repository was cloned
438+
pub fn path(&self) -> &Path
439+
where
440+
P: AsRef<Path>,
441+
{
442+
self.dir_path.as_ref()
443+
}
444+
445+
/// Get the relative path of a file contained in this repository
446+
pub fn relative_path<'f, F>(&self, file: &'f F) -> &'f Path
447+
where
448+
P: AsRef<Path>,
449+
F: AsRef<Path>,
450+
{
451+
file.as_ref()
452+
.strip_prefix(self.dir_path.as_ref())
453+
.unwrap_or(file.as_ref())
454+
}
455+
}
456+
414457
/// Clone a git repository
415458
///
416459
/// Parameters:
@@ -641,76 +684,111 @@ pub fn search_for_rs_files(repo: &Path) -> impl Iterator<Item = PathBuf> {
641684
})
642685
}
643686

687+
/// Encapsulate the logic used to clone repositories for the diff check
688+
pub fn clone_repositories_for_diff_check(
689+
repositories: &[&str],
690+
) -> Vec<Repository<tempfile::TempDir>> {
691+
// Use a Hashmap to deduplicate any repositories
692+
let map = Arc::new(Mutex::new(HashMap::new()));
693+
694+
std::thread::scope(|s| {
695+
for url in repositories {
696+
let map = Arc::clone(&map);
697+
698+
s.spawn(move || {
699+
let repo_name = get_repo_name(url);
700+
info!("Processing repo: {repo_name}");
701+
let Ok(tmp_dir) = tempdir() else {
702+
warn!(
703+
"Failed to create a tempdir for {}. Can't check formatting diff for {}",
704+
&url, repo_name
705+
);
706+
return;
707+
};
708+
709+
let Ok(_) = clone_git_repo(url, tmp_dir.path()) else {
710+
warn!(
711+
"Failed to clone repo {}. Can't check formatting diff for {}",
712+
&url, repo_name
713+
);
714+
return;
715+
};
716+
717+
let repo = Repository::new(url, tmp_dir);
718+
map.lock().unwrap().insert(repo_name.to_string(), repo);
719+
});
720+
}
721+
});
722+
723+
let map = match Arc::into_inner(map)
724+
.expect("All other threads are done")
725+
.into_inner()
726+
{
727+
Ok(map) => map,
728+
Err(e) => e.into_inner(),
729+
};
730+
731+
map.into_values().collect()
732+
}
733+
644734
/// Calculates the number of errors when running the compiled binary and the feature binary on the
645735
/// repo specified with the specific configs.
646-
pub fn check_diff<P: AsRef<Path>>(
736+
pub fn check_diff_for_file<'repo, P: AsRef<Path>, F: AsRef<Path>>(
647737
runners: &CheckDiffRunners<impl CodeFormatter, impl CodeFormatter>,
648-
repo: P,
649-
repo_url: &str,
650-
) -> u8 {
651-
let mut errors: u8 = 0;
652-
let repo = repo.as_ref();
653-
let iter = search_for_rs_files(repo);
654-
for file in iter {
655-
let relative_path = file.strip_prefix(repo).unwrap_or(&file);
656-
let repo_name = get_repo_name(repo_url);
657-
658-
trace!(
659-
"Formatting '{0}' file {0}/{1}",
660-
repo_name,
661-
relative_path.display()
662-
);
663-
664-
match runners.create_diff(file.as_path()) {
665-
Ok(diff) => {
666-
if !diff.is_empty() {
667-
error!(
668-
"Diff found in '{0}' when formatting {0}/{1}\n{2}",
669-
repo_name,
670-
relative_path.display(),
671-
diff,
672-
);
673-
errors = errors.saturating_add(1);
674-
} else {
675-
trace!(
676-
"No diff found in '{0}' when formatting {0}/{1}",
677-
repo_name,
678-
relative_path.display(),
679-
)
680-
}
681-
}
682-
Err(CreateDiffError::MainRustfmtFailed(e)) => {
683-
debug!(
684-
"`main` rustfmt failed to format {}/{}\n{:?}",
685-
repo_name,
686-
relative_path.display(),
687-
e,
688-
);
689-
continue;
690-
}
691-
Err(CreateDiffError::FeatureRustfmtFailed(e)) => {
692-
debug!(
693-
"`feature` rustfmt failed to format {}/{}\n{:?}",
694-
repo_name,
695-
relative_path.display(),
696-
e,
697-
);
698-
continue;
699-
}
700-
Err(CreateDiffError::BothRustfmtFailed { src, feature }) => {
701-
debug!(
702-
"Both rustfmt binaries failed to format {}/{}\n{:?}\n{:?}",
738+
repo: &'repo Repository<P>,
739+
file: F,
740+
) -> Result<(), (Diff, F, &'repo Repository<P>)> {
741+
let relative_path = repo.relative_path(&file);
742+
let repo_name = repo.name();
743+
744+
trace!(
745+
"Formatting '{0}' file {0}/{1}",
746+
repo_name,
747+
relative_path.display()
748+
);
749+
750+
match runners.create_diff(file.as_ref()) {
751+
Ok(diff) => {
752+
if !diff.is_empty() {
753+
Err((diff, file, repo))
754+
} else {
755+
trace!(
756+
"No diff found in '{0}' when formatting {0}/{1}",
703757
repo_name,
704758
relative_path.display(),
705-
src,
706-
feature,
707759
);
708-
continue;
760+
Ok(())
709761
}
710762
}
763+
Err(CreateDiffError::MainRustfmtFailed(e)) => {
764+
debug!(
765+
"`main` rustfmt failed to format {}/{}\n{:?}",
766+
repo_name,
767+
relative_path.display(),
768+
e,
769+
);
770+
Ok(())
771+
}
772+
Err(CreateDiffError::FeatureRustfmtFailed(e)) => {
773+
debug!(
774+
"`feature` rustfmt failed to format {}/{}\n{:?}",
775+
repo_name,
776+
relative_path.display(),
777+
e,
778+
);
779+
Ok(())
780+
}
781+
Err(CreateDiffError::BothRustfmtFailed { src, feature }) => {
782+
debug!(
783+
"Both rustfmt binaries failed to format {}/{}\n{:?}\n{:?}",
784+
repo_name,
785+
relative_path.display(),
786+
src,
787+
feature,
788+
);
789+
Ok(())
790+
}
711791
}
712-
713-
errors
714792
}
715793

716794
/// parse out the repository name from a GitHub Repository name.
@@ -721,3 +799,60 @@ pub fn get_repo_name(git_url: &str) -> &str {
721799
.unwrap_or(("", strip_git_prefix));
722800
repo_name
723801
}
802+
803+
pub fn check_diff<'repo, P, F, M>(
804+
runners: &CheckDiffRunners<F, M>,
805+
repositories: &'repo [Repository<P>],
806+
worker_threads: std::num::NonZeroU8,
807+
) -> Vec<(Diff, PathBuf, &'repo Repository<P>)>
808+
where
809+
P: AsRef<Path> + Sync + Send,
810+
F: CodeFormatter + Sync,
811+
M: CodeFormatter + Sync,
812+
{
813+
let (tx, rx) = crossbeam_channel::unbounded();
814+
815+
let errors = std::thread::scope(|s| {
816+
// Spawn producer threads that find files to check
817+
for repo in repositories.iter() {
818+
let tx = tx.clone();
819+
s.spawn(move || {
820+
for file in search_for_rs_files(repo.path()) {
821+
let _ = tx.send((file, repo));
822+
}
823+
});
824+
}
825+
826+
// Drop the first `tx` we created. Now there's exactly one `tx` per producer thread so when
827+
// each producer thread finishes the receiving threads will start to get Err(RecvError)
828+
// when calling `rx.recv()` and they'll know to stop processing files.
829+
// When all scoped threads end we'll know we're done with processing and we can return
830+
// any errors we found to the caller.
831+
drop(tx);
832+
833+
let errors = Arc::new(Mutex::new(Vec::with_capacity(10)));
834+
835+
// spawn receiver threads used to process all files:
836+
for _ in 0..u8::from(worker_threads) {
837+
let errors = Arc::clone(&errors);
838+
let rx = rx.clone();
839+
s.spawn(move || {
840+
while let Ok((file, repo)) = rx.recv() {
841+
if let Err(e) = check_diff_for_file(runners, repo, file) {
842+
// Push errors to report on later
843+
errors.lock().unwrap().push(e);
844+
}
845+
}
846+
});
847+
}
848+
errors
849+
});
850+
851+
match Arc::into_inner(errors)
852+
.expect("All other threads are done")
853+
.into_inner()
854+
{
855+
Ok(e) => e,
856+
Err(e) => e.into_inner(),
857+
}
858+
}

0 commit comments

Comments
 (0)