Skip to content

Commit 9f2206c

Browse files
committed
refactor: add type annotations to util/vcs.py and reformat
1 parent 4d8e75b commit 9f2206c

File tree

1 file changed

+34
-24
lines changed

1 file changed

+34
-24
lines changed

src/taskgraph/util/vcs.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import subprocess
1010
from abc import ABC, abstractmethod
1111
from shutil import which
12+
from typing import List, Optional
1213

1314
from taskgraph.util.path import ancestors
1415

@@ -34,7 +35,7 @@ def __init__(self, path):
3435

3536
self._env = os.environ.copy()
3637

37-
def run(self, *args: str, **kwargs):
38+
def run(self, *args: str, **kwargs) -> str:
3839
return_codes = kwargs.pop("return_codes", [])
3940
cmd = (self.binary,) + args
4041

@@ -63,17 +64,17 @@ def head_rev(self) -> str:
6364

6465
@property
6566
@abstractmethod
66-
def base_rev(self):
67+
def base_rev(self) -> str:
6768
"""Hash of revision the current topic branch is based on."""
6869

6970
@property
7071
@abstractmethod
71-
def branch(self):
72+
def branch(self) -> str | None:
7273
"""Current branch or bookmark the checkout has active."""
7374

7475
@property
7576
@abstractmethod
76-
def all_remote_names(self):
77+
def all_remote_names(self) -> List[str]:
7778
"""Name of all configured remote repositories."""
7879

7980
@property
@@ -85,10 +86,10 @@ def default_remote_name(self) -> str:
8586

8687
@property
8788
@abstractmethod
88-
def remote_name(self):
89+
def remote_name(self) -> str:
8990
"""Name of the remote repository."""
9091

91-
def _get_most_suitable_remote(self, remote_instructions):
92+
def _get_most_suitable_remote(self, remote_instructions) -> str:
9293
remotes = self.all_remote_names
9394

9495
# in case all_remote_names raised a RuntimeError
@@ -113,19 +114,25 @@ def _get_most_suitable_remote(self, remote_instructions):
113114

114115
@property
115116
@abstractmethod
116-
def default_branch(self):
117+
def default_branch(self) -> str:
117118
"""Name of the default branch."""
118119

119120
@abstractmethod
120-
def get_url(self, remote=None):
121+
def get_url(self, remote: Optional[str]) -> str:
121122
"""Get URL of the upstream repository."""
122123

123124
@abstractmethod
124-
def get_commit_message(self, revision=None):
125+
def get_commit_message(self, revision: Optional[str]) -> str:
125126
"""Commit message of specified revision or current commit."""
126127

127128
@abstractmethod
128-
def get_changed_files(self, diff_filter, mode="unstaged", rev=None, base_rev=None):
129+
def get_changed_files(
130+
self,
131+
diff_filter: Optional[str],
132+
mode: Optional[str],
133+
rev: Optional[str],
134+
base_rev: Optional[str],
135+
) -> List[str]:
129136
"""Return a list of files that are changed in:
130137
* either this repository's working copy,
131138
* or at a given revision (``rev``)
@@ -152,7 +159,7 @@ def get_changed_files(self, diff_filter, mode="unstaged", rev=None, base_rev=Non
152159
"""
153160

154161
@abstractmethod
155-
def get_outgoing_files(self, diff_filter, upstream):
162+
def get_outgoing_files(self, diff_filter: str, upstream: str) -> List[str]:
156163
"""Return a list of changed files compared to upstream.
157164
158165
``diff_filter`` works the same as `get_changed_files`.
@@ -162,7 +169,9 @@ def get_outgoing_files(self, diff_filter, upstream):
162169
"""
163170

164171
@abstractmethod
165-
def working_directory_clean(self, untracked=False, ignored=False):
172+
def working_directory_clean(
173+
self, untracked: Optional[bool] = False, ignored: Optional[bool] = False
174+
) -> bool:
166175
"""Determine if the working directory is free of modifications.
167176
168177
Returns True if the working directory does not have any file
@@ -174,19 +183,19 @@ def working_directory_clean(self, untracked=False, ignored=False):
174183
"""
175184

176185
@abstractmethod
177-
def update(self, ref):
186+
def update(self, ref: str) -> None:
178187
"""Update the working directory to the specified reference."""
179188

180189
@abstractmethod
181-
def find_latest_common_revision(self, base_ref_or_rev, head_rev):
190+
def find_latest_common_revision(self, base_ref_or_rev: str, head_rev: str) -> str:
182191
"""Find the latest revision that is common to both the given
183192
``head_rev`` and ``base_ref_or_rev``.
184193
185194
If no common revision exists, ``Repository.NULL_REVISION`` will
186195
be returned."""
187196

188197
@abstractmethod
189-
def does_revision_exist_locally(self, revision):
198+
def does_revision_exist_locally(self, revision: str) -> bool:
190199
"""Check whether this revision exists in the local repository.
191200
192201
If this function returns an unexpected value, then make sure
@@ -243,7 +252,8 @@ def default_branch(self):
243252
# https://www.mercurial-scm.org/wiki/StandardBranching#Don.27t_use_a_name_other_than_default_for_your_main_development_branch
244253
return "default"
245254

246-
def get_url(self, remote="default"):
255+
def get_url(self, remote=None):
256+
remote = remote or "default"
247257
return self.run("path", "-T", "{url}", remote).strip()
248258

249259
def get_commit_message(self, revision=None):
@@ -270,9 +280,8 @@ def _files_template(self, diff_filter):
270280
template += "{file_mods % '{file}\\n'}"
271281
return template
272282

273-
def get_changed_files(
274-
self, diff_filter="ADM", mode="unstaged", rev=None, base_rev=None
275-
):
283+
def get_changed_files(self, diff_filter=None, mode=None, rev=None, base_rev=None):
284+
diff_filter = diff_filter or "ADM"
276285
if rev is None:
277286
if base_rev is not None:
278287
raise ValueError("Cannot specify `base_rev` without `rev`")
@@ -315,7 +324,7 @@ def working_directory_clean(self, untracked=False, ignored=False):
315324
return not len(self.run(*args).strip())
316325

317326
def update(self, ref):
318-
return self.run("update", "--check", ref)
327+
self.run("update", "--check", ref)
319328

320329
def find_latest_common_revision(self, base_ref_or_rev, head_rev):
321330
ancestor = self.run(
@@ -445,16 +454,17 @@ def _guess_default_branch(self):
445454

446455
raise RuntimeError(f"Unable to find default branch. Got: {branches}")
447456

448-
def get_url(self, remote="origin"):
457+
def get_url(self, remote=None):
458+
remote = remote or "origin"
449459
return self.run("remote", "get-url", remote).strip()
450460

451461
def get_commit_message(self, revision=None):
452462
revision = revision or "HEAD"
453463
return self.run("log", "-n1", "--format=%B", revision)
454464

455-
def get_changed_files(
456-
self, diff_filter="ADM", mode="unstaged", rev=None, base_rev=None
457-
):
465+
def get_changed_files(self, diff_filter=None, mode=None, rev=None, base_rev=None):
466+
diff_filter = diff_filter or "ADM"
467+
mode = mode or "unstaged"
458468
assert all(f.lower() in self._valid_diff_filter for f in diff_filter)
459469

460470
if rev is None:

0 commit comments

Comments
 (0)