99import subprocess
1010from abc import ABC , abstractmethod
1111from shutil import which
12+ from typing import List , Optional
1213
1314from taskgraph .util .path import ancestors
1415
@@ -34,7 +35,7 @@ def __init__(self, path):
3435
3536 self ._env = os .environ .copy ()
3637
37- def run (self , * args : str , ** kwargs ):
38+ def run (self , * args : str , ** kwargs ) -> str :
3839 return_codes = kwargs .pop ("return_codes" , [])
3940 cmd = (self .binary ,) + args
4041
@@ -63,17 +64,17 @@ def head_rev(self) -> str:
6364
6465 @property
6566 @abstractmethod
66- def base_rev (self ):
67+ def base_rev (self ) -> str :
6768 """Hash of revision the current topic branch is based on."""
6869
6970 @property
7071 @abstractmethod
71- def branch (self ):
72+ def branch (self ) -> Optional [ str ] :
7273 """Current branch or bookmark the checkout has active."""
7374
7475 @property
7576 @abstractmethod
76- def all_remote_names (self ):
77+ def all_remote_names (self ) -> List [ str ] :
7778 """Name of all configured remote repositories."""
7879
7980 @property
@@ -85,10 +86,10 @@ def default_remote_name(self) -> str:
8586
8687 @property
8788 @abstractmethod
88- def remote_name (self ):
89+ def remote_name (self ) -> str :
8990 """Name of the remote repository."""
9091
91- def _get_most_suitable_remote (self , remote_instructions ):
92+ def _get_most_suitable_remote (self , remote_instructions ) -> str :
9293 remotes = self .all_remote_names
9394
9495 # in case all_remote_names raised a RuntimeError
@@ -113,19 +114,34 @@ def _get_most_suitable_remote(self, remote_instructions):
113114
114115 @property
115116 @abstractmethod
116- def default_branch (self ):
117+ def default_branch (self ) -> str :
117118 """Name of the default branch."""
118119
119120 @abstractmethod
120- def get_url (self , remote = None ) :
121+ def get_url (self , remote : Optional [ str ]) -> str :
121122 """Get URL of the upstream repository."""
122123
123124 @abstractmethod
124- def get_commit_message (self , revision = None ) :
125+ def get_commit_message (self , revision : Optional [ str ]) -> str :
125126 """Commit message of specified revision or current commit."""
126127
127128 @abstractmethod
128- def get_changed_files (self , diff_filter , mode = "unstaged" , rev = None , base_rev = None ):
129+ def get_tracked_files (self , * paths : str , rev : Optional [str ] = None ) -> List [str ]:
130+ """Return list of tracked files.
131+
132+ ``*paths`` are path specifiers to limit results to.
133+ ``rev`` is a revision specifier at which to retrieve the files.
134+ Defaults to the parent of the working copy if unspecified.
135+ """
136+
137+ @abstractmethod
138+ def get_changed_files (
139+ self ,
140+ diff_filter : Optional [str ],
141+ mode : Optional [str ],
142+ rev : Optional [str ],
143+ base_rev : Optional [str ],
144+ ) -> List [str ]:
129145 """Return a list of files that are changed in:
130146 * either this repository's working copy,
131147 * or at a given revision (``rev``)
@@ -152,7 +168,7 @@ def get_changed_files(self, diff_filter, mode="unstaged", rev=None, base_rev=Non
152168 """
153169
154170 @abstractmethod
155- def get_outgoing_files (self , diff_filter , upstream ) :
171+ def get_outgoing_files (self , diff_filter : str , upstream : str ) -> List [ str ] :
156172 """Return a list of changed files compared to upstream.
157173
158174 ``diff_filter`` works the same as `get_changed_files`.
@@ -162,7 +178,9 @@ def get_outgoing_files(self, diff_filter, upstream):
162178 """
163179
164180 @abstractmethod
165- def working_directory_clean (self , untracked = False , ignored = False ):
181+ def working_directory_clean (
182+ self , untracked : Optional [bool ] = False , ignored : Optional [bool ] = False
183+ ) -> bool :
166184 """Determine if the working directory is free of modifications.
167185
168186 Returns True if the working directory does not have any file
@@ -174,19 +192,19 @@ def working_directory_clean(self, untracked=False, ignored=False):
174192 """
175193
176194 @abstractmethod
177- def update (self , ref ) :
195+ def update (self , ref : str ) -> None :
178196 """Update the working directory to the specified reference."""
179197
180198 @abstractmethod
181- def find_latest_common_revision (self , base_ref_or_rev , head_rev ) :
199+ def find_latest_common_revision (self , base_ref_or_rev : str , head_rev : str ) -> str :
182200 """Find the latest revision that is common to both the given
183201 ``head_rev`` and ``base_ref_or_rev``.
184202
185203 If no common revision exists, ``Repository.NULL_REVISION`` will
186204 be returned."""
187205
188206 @abstractmethod
189- def does_revision_exist_locally (self , revision ) :
207+ def does_revision_exist_locally (self , revision : str ) -> bool :
190208 """Check whether this revision exists in the local repository.
191209
192210 If this function returns an unexpected value, then make sure
@@ -243,7 +261,8 @@ def default_branch(self):
243261 # https://www.mercurial-scm.org/wiki/StandardBranching#Don.27t_use_a_name_other_than_default_for_your_main_development_branch
244262 return "default"
245263
246- def get_url (self , remote = "default" ):
264+ def get_url (self , remote = None ):
265+ remote = remote or "default"
247266 return self .run ("path" , "-T" , "{url}" , remote ).strip ()
248267
249268 def get_commit_message (self , revision = None ):
@@ -270,9 +289,12 @@ def _files_template(self, diff_filter):
270289 template += "{file_mods % '{file}\\ n'}"
271290 return template
272291
273- def get_changed_files (
274- self , diff_filter = "ADM" , mode = "unstaged" , rev = None , base_rev = None
275- ):
292+ def get_tracked_files (self , * paths , rev = None ):
293+ rev = rev or "."
294+ return self .run ("files" , "-r" , rev , * paths ).splitlines ()
295+
296+ def get_changed_files (self , diff_filter = None , mode = None , rev = None , base_rev = None ):
297+ diff_filter = diff_filter or "ADM"
276298 if rev is None :
277299 if base_rev is not None :
278300 raise ValueError ("Cannot specify `base_rev` without `rev`" )
@@ -315,7 +337,7 @@ def working_directory_clean(self, untracked=False, ignored=False):
315337 return not len (self .run (* args ).strip ())
316338
317339 def update (self , ref ):
318- return self .run ("update" , "--check" , ref )
340+ self .run ("update" , "--check" , ref )
319341
320342 def find_latest_common_revision (self , base_ref_or_rev , head_rev ):
321343 ancestor = self .run (
@@ -445,16 +467,21 @@ def _guess_default_branch(self):
445467
446468 raise RuntimeError (f"Unable to find default branch. Got: { branches } " )
447469
448- def get_url (self , remote = "origin" ):
470+ def get_url (self , remote = None ):
471+ remote = remote or "origin"
449472 return self .run ("remote" , "get-url" , remote ).strip ()
450473
451474 def get_commit_message (self , revision = None ):
452475 revision = revision or "HEAD"
453476 return self .run ("log" , "-n1" , "--format=%B" , revision )
454477
455- def get_changed_files (
456- self , diff_filter = "ADM" , mode = "unstaged" , rev = None , base_rev = None
457- ):
478+ def get_tracked_files (self , * paths , rev = None ):
479+ rev = rev or "HEAD"
480+ return self .run ("ls-tree" , "-r" , "--name-only" , rev , * paths ).splitlines ()
481+
482+ def get_changed_files (self , diff_filter = None , mode = None , rev = None , base_rev = None ):
483+ diff_filter = diff_filter or "ADM"
484+ mode = mode or "unstaged"
458485 assert all (f .lower () in self ._valid_diff_filter for f in diff_filter )
459486
460487 if rev is None :
0 commit comments