11use std:: borrow:: Cow ;
2+ use std:: collections:: HashMap ;
23use std:: env;
34use std:: fmt:: { Debug , Display } ;
45use std:: io:: { self , Write } ;
56use std:: path:: { Path , PathBuf } ;
67use std:: process:: { Command , Stdio } ;
78use std:: str:: FromStr ;
8- use tracing:: { debug, error, info, trace} ;
9+ use std:: sync:: { Arc , Mutex } ;
10+ use tempfile:: tempdir;
11+ use tracing:: { debug, info, trace, warn} ;
912use walkdir:: WalkDir ;
1013
1114#[ derive( Debug , Clone , Copy ) ]
@@ -411,6 +414,46 @@ fn create_config_arg<T: AsRef<str>>(configs: Option<&[T]>) -> Cow<'static, str>
411414
412415 Cow :: Owned ( result)
413416}
417+
418+ pub struct Repository < P > {
419+ /// Name of the repository
420+ name : String ,
421+ /// Path to the repository on the local file system
422+ dir_path : P ,
423+ }
424+
425+ impl < P > Repository < P > {
426+ /// Initialize a new Repository
427+ pub fn new ( git_url : & str , dir_path : P ) -> Self {
428+ let name = get_repo_name ( git_url) . to_string ( ) ;
429+ Self { name, dir_path }
430+ }
431+
432+ /// Get the `name` of the repository
433+ pub fn name ( & self ) -> & str {
434+ & self . name
435+ }
436+
437+ /// Get the absolute path to where this repository was cloned
438+ pub fn path ( & self ) -> & Path
439+ where
440+ P : AsRef < Path > ,
441+ {
442+ self . dir_path . as_ref ( )
443+ }
444+
445+ /// Get the relative path of a file contained in this repository
446+ pub fn relative_path < ' f , F > ( & self , file : & ' f F ) -> & ' f Path
447+ where
448+ P : AsRef < Path > ,
449+ F : AsRef < Path > ,
450+ {
451+ file. as_ref ( )
452+ . strip_prefix ( self . dir_path . as_ref ( ) )
453+ . unwrap_or ( file. as_ref ( ) )
454+ }
455+ }
456+
414457/// Clone a git repository
415458///
416459/// Parameters:
@@ -641,76 +684,111 @@ pub fn search_for_rs_files(repo: &Path) -> impl Iterator<Item = PathBuf> {
641684 } )
642685}
643686
687+ /// Encapsulate the logic used to clone repositories for the diff check
688+ pub fn clone_repositories_for_diff_check (
689+ repositories : & [ & str ] ,
690+ ) -> Vec < Repository < tempfile:: TempDir > > {
691+ // Use a Hashmap to deduplicate any repositories
692+ let map = Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ;
693+
694+ std:: thread:: scope ( |s| {
695+ for url in repositories {
696+ let map = Arc :: clone ( & map) ;
697+
698+ s. spawn ( move || {
699+ let repo_name = get_repo_name ( url) ;
700+ info ! ( "Processing repo: {repo_name}" ) ;
701+ let Ok ( tmp_dir) = tempdir ( ) else {
702+ warn ! (
703+ "Failed to create a tempdir for {}. Can't check formatting diff for {}" ,
704+ & url, repo_name
705+ ) ;
706+ return ;
707+ } ;
708+
709+ let Ok ( _) = clone_git_repo ( url, tmp_dir. path ( ) ) else {
710+ warn ! (
711+ "Failed to clone repo {}. Can't check formatting diff for {}" ,
712+ & url, repo_name
713+ ) ;
714+ return ;
715+ } ;
716+
717+ let repo = Repository :: new ( url, tmp_dir) ;
718+ map. lock ( ) . unwrap ( ) . insert ( repo_name. to_string ( ) , repo) ;
719+ } ) ;
720+ }
721+ } ) ;
722+
723+ let map = match Arc :: into_inner ( map)
724+ . expect ( "All other threads are done" )
725+ . into_inner ( )
726+ {
727+ Ok ( map) => map,
728+ Err ( e) => e. into_inner ( ) ,
729+ } ;
730+
731+ map. into_values ( ) . collect ( )
732+ }
733+
644734/// Calculates the number of errors when running the compiled binary and the feature binary on the
645735/// repo specified with the specific configs.
646- pub fn check_diff < P : AsRef < Path > > (
736+ pub fn check_diff_for_file < ' repo , P : AsRef < Path > , F : AsRef < Path > > (
647737 runners : & CheckDiffRunners < impl CodeFormatter , impl CodeFormatter > ,
648- repo : P ,
649- repo_url : & str ,
650- ) -> u8 {
651- let mut errors: u8 = 0 ;
652- let repo = repo. as_ref ( ) ;
653- let iter = search_for_rs_files ( repo) ;
654- for file in iter {
655- let relative_path = file. strip_prefix ( repo) . unwrap_or ( & file) ;
656- let repo_name = get_repo_name ( repo_url) ;
657-
658- trace ! (
659- "Formatting '{0}' file {0}/{1}" ,
660- repo_name,
661- relative_path. display( )
662- ) ;
663-
664- match runners. create_diff ( file. as_path ( ) ) {
665- Ok ( diff) => {
666- if !diff. is_empty ( ) {
667- error ! (
668- "Diff found in '{0}' when formatting {0}/{1}\n {2}" ,
669- repo_name,
670- relative_path. display( ) ,
671- diff,
672- ) ;
673- errors = errors. saturating_add ( 1 ) ;
674- } else {
675- trace ! (
676- "No diff found in '{0}' when formatting {0}/{1}" ,
677- repo_name,
678- relative_path. display( ) ,
679- )
680- }
681- }
682- Err ( CreateDiffError :: MainRustfmtFailed ( e) ) => {
683- debug ! (
684- "`main` rustfmt failed to format {}/{}\n {:?}" ,
685- repo_name,
686- relative_path. display( ) ,
687- e,
688- ) ;
689- continue ;
690- }
691- Err ( CreateDiffError :: FeatureRustfmtFailed ( e) ) => {
692- debug ! (
693- "`feature` rustfmt failed to format {}/{}\n {:?}" ,
694- repo_name,
695- relative_path. display( ) ,
696- e,
697- ) ;
698- continue ;
699- }
700- Err ( CreateDiffError :: BothRustfmtFailed { src, feature } ) => {
701- debug ! (
702- "Both rustfmt binaries failed to format {}/{}\n {:?}\n {:?}" ,
738+ repo : & ' repo Repository < P > ,
739+ file : F ,
740+ ) -> Result < ( ) , ( Diff , F , & ' repo Repository < P > ) > {
741+ let relative_path = repo. relative_path ( & file) ;
742+ let repo_name = repo. name ( ) ;
743+
744+ trace ! (
745+ "Formatting '{0}' file {0}/{1}" ,
746+ repo_name,
747+ relative_path. display( )
748+ ) ;
749+
750+ match runners. create_diff ( file. as_ref ( ) ) {
751+ Ok ( diff) => {
752+ if !diff. is_empty ( ) {
753+ Err ( ( diff, file, repo) )
754+ } else {
755+ trace ! (
756+ "No diff found in '{0}' when formatting {0}/{1}" ,
703757 repo_name,
704758 relative_path. display( ) ,
705- src,
706- feature,
707759 ) ;
708- continue ;
760+ Ok ( ( ) )
709761 }
710762 }
763+ Err ( CreateDiffError :: MainRustfmtFailed ( e) ) => {
764+ debug ! (
765+ "`main` rustfmt failed to format {}/{}\n {:?}" ,
766+ repo_name,
767+ relative_path. display( ) ,
768+ e,
769+ ) ;
770+ Ok ( ( ) )
771+ }
772+ Err ( CreateDiffError :: FeatureRustfmtFailed ( e) ) => {
773+ debug ! (
774+ "`feature` rustfmt failed to format {}/{}\n {:?}" ,
775+ repo_name,
776+ relative_path. display( ) ,
777+ e,
778+ ) ;
779+ Ok ( ( ) )
780+ }
781+ Err ( CreateDiffError :: BothRustfmtFailed { src, feature } ) => {
782+ debug ! (
783+ "Both rustfmt binaries failed to format {}/{}\n {:?}\n {:?}" ,
784+ repo_name,
785+ relative_path. display( ) ,
786+ src,
787+ feature,
788+ ) ;
789+ Ok ( ( ) )
790+ }
711791 }
712-
713- errors
714792}
715793
716794/// parse out the repository name from a GitHub Repository name.
@@ -721,3 +799,60 @@ pub fn get_repo_name(git_url: &str) -> &str {
721799 . unwrap_or ( ( "" , strip_git_prefix) ) ;
722800 repo_name
723801}
802+
803+ pub fn check_diff < ' repo , P , F , M > (
804+ runners : & CheckDiffRunners < F , M > ,
805+ repositories : & ' repo [ Repository < P > ] ,
806+ worker_threads : std:: num:: NonZeroU8 ,
807+ ) -> Vec < ( Diff , PathBuf , & ' repo Repository < P > ) >
808+ where
809+ P : AsRef < Path > + Sync + Send ,
810+ F : CodeFormatter + Sync ,
811+ M : CodeFormatter + Sync ,
812+ {
813+ let ( tx, rx) = crossbeam_channel:: unbounded ( ) ;
814+
815+ let errors = std:: thread:: scope ( |s| {
816+ // Spawn producer threads that find files to check
817+ for repo in repositories. iter ( ) {
818+ let tx = tx. clone ( ) ;
819+ s. spawn ( move || {
820+ for file in search_for_rs_files ( repo. path ( ) ) {
821+ let _ = tx. send ( ( file, repo) ) ;
822+ }
823+ } ) ;
824+ }
825+
826+ // Drop the first `tx` we created. Now there's exactly one `tx` per producer thread so when
827+ // each producer thread finishes the receiving threads will start to get Err(RecvError)
828+ // when calling `rx.recv()` and they'll know to stop processing files.
829+ // When all scoped threads end we'll know we're done with processing and we can return
830+ // any errors we found to the caller.
831+ drop ( tx) ;
832+
833+ let errors = Arc :: new ( Mutex :: new ( Vec :: with_capacity ( 10 ) ) ) ;
834+
835+ // spawn receiver threads used to process all files:
836+ for _ in 0 ..u8:: from ( worker_threads) {
837+ let errors = Arc :: clone ( & errors) ;
838+ let rx = rx. clone ( ) ;
839+ s. spawn ( move || {
840+ while let Ok ( ( file, repo) ) = rx. recv ( ) {
841+ if let Err ( e) = check_diff_for_file ( runners, repo, file) {
842+ // Push errors to report on later
843+ errors. lock ( ) . unwrap ( ) . push ( e) ;
844+ }
845+ }
846+ } ) ;
847+ }
848+ errors
849+ } ) ;
850+
851+ match Arc :: into_inner ( errors)
852+ . expect ( "All other threads are done" )
853+ . into_inner ( )
854+ {
855+ Ok ( e) => e,
856+ Err ( e) => e. into_inner ( ) ,
857+ }
858+ }
0 commit comments