Skip to content
This repository was archived by the owner on Oct 21, 2024. It is now read-only.

Commit 215b44f

Browse files
authored
feat: prevent overriding commits in a branch that were not made by rcmt (#439)
1 parent c54d998 commit 215b44f

File tree

10 files changed

+468
-11
lines changed

10 files changed

+468
-11
lines changed

rcmt/git.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
log: structlog.stdlib.BoundLogger = structlog.get_logger(package="git")
1616

1717

18+
class BranchModifiedError(RuntimeError):
19+
def __init__(self, checksums: list[str]):
20+
self.checksums = checksums
21+
22+
1823
class Git:
1924
def __init__(
2025
self,
@@ -38,6 +43,22 @@ def commit_changes(self, repo_dir: str, msg: str):
3843
git_repo.git.add(all=True)
3944
git_repo.index.commit(msg)
4045

46+
def _detect_modified_branch(
47+
self,
48+
merge_base: str,
49+
repo: git.Repo,
50+
) -> None:
51+
foreign_commits: list[str] = []
52+
for commit in repo.iter_commits():
53+
if commit.hexsha == merge_base:
54+
break
55+
56+
if commit.author.email != self.user_email:
57+
foreign_commits.append(commit.hexsha)
58+
59+
if len(foreign_commits) > 0:
60+
raise BranchModifiedError(foreign_commits)
61+
4162
@staticmethod
4263
def has_changes_origin(branch: str, repo_dir: str) -> bool:
4364
git_repo = git.Repo(path=repo_dir)
@@ -54,7 +75,9 @@ def has_changes_local(repo_dir: str) -> bool:
5475
git_repo = git.Repo(path=repo_dir)
5576
return len(git_repo.index.diff(None)) > 0 or len(git_repo.untracked_files) > 0
5677

57-
def prepare(self, repo: source.Repository, iteration: int = 0) -> Tuple[str, bool]:
78+
def prepare(
79+
self, repo: source.Repository, force_rebase: bool, iteration: int = 0
80+
) -> Tuple[str, bool]:
5881
"""
5982
1. Clone repository
6083
2. Checkout base branch
@@ -94,7 +117,7 @@ def prepare(self, repo: source.Repository, iteration: int = 0) -> Tuple[str, boo
94117
branch=repo.base_branch,
95118
)
96119
shutil.rmtree(checkout_dir)
97-
return self.prepare(repo, iteration=1)
120+
return self.prepare(force_rebase=force_rebase, iteration=1, repo=repo)
98121

99122
hash_before_pull = str(git_repo.head.commit)
100123
log.debug(
@@ -146,7 +169,18 @@ def prepare(self, repo: source.Repository, iteration: int = 0) -> Tuple[str, boo
146169

147170
log.debug("Checking out work branch", branch=self.branch_name)
148171
git_repo.heads[self.branch_name].checkout()
172+
if remote_branch is not None:
173+
log.debug("Pulling changes into work branch", branch=self.branch_name)
174+
# `rebase=True` to end up with a clean history.
175+
# `strategy_option="theirs"` to always prefer changes from the remote.
176+
# Commits by someone else will be preserved with this strategy and there
177+
# will be no conflict.
178+
git_repo.remotes["origin"].pull(rebase=True, strategy_option="theirs")
179+
149180
merge_base = git_repo.git.merge_base(repo.base_branch, self.branch_name)
181+
if force_rebase is False:
182+
self._detect_modified_branch(merge_base=merge_base, repo=git_repo)
183+
150184
log.debug("Resetting to merge base", branch=self.branch_name)
151185
git_repo.git.reset(merge_base, hard=True)
152186
log.debug("Rebasing onto work branch", branch=self.branch_name)

rcmt/rcmt.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
import datetime
66
import shutil
77
from enum import Enum
8-
from typing import Iterator, Optional
8+
from typing import Any, Iterator, Optional
99

10+
import jinja2
1011
import structlog
1112
from git.exc import GitCommandError
1213

@@ -16,6 +17,24 @@
1617
log: structlog.stdlib.BoundLogger = structlog.get_logger()
1718

1819

20+
TEMPLATE_BRANCH_MODIFIED = jinja2.Template(
21+
""":warning: **This pull request has been modified.**
22+
23+
This is a safety mechanism to prevent rcmt from accidentally overriding custom commits.
24+
25+
rcmt will not be able to resolve merge conflicts with `{{ default_branch }}` automatically.
26+
It will not update this pull request or auto-merge it.
27+
28+
Check the box in the description of this PR to force a rebase. This will remove all commits not made by rcmt.
29+
30+
The commit(s) that modified the pull request:
31+
{% for checksum in checksums %}
32+
- {{ checksum }}
33+
{% endfor %}
34+
"""
35+
)
36+
37+
1938
class Options:
2039
def __init__(self, cfg: config.Config):
2140
self.config = cfg
@@ -75,8 +94,28 @@ def execute(
7594
)
7695
return RunResult.NO_CHANGES
7796

97+
force_rebase = self._has_rebase_checked(pr=pr_identifier, repo=repo)
7898
try:
79-
work_dir, has_conflict = self.git.prepare(repo)
99+
work_dir, has_conflict = self.git.prepare(
100+
force_rebase=force_rebase, repo=repo
101+
)
102+
if force_rebase is True:
103+
# It is likely that the branch was modified previously and a comment was
104+
# created to notify the user. Delete that comment.
105+
repo.delete_pr_comment_with_identifier(
106+
identifier="branch-modified", pr=pr_identifier
107+
)
108+
109+
except git.BranchModifiedError as e:
110+
log.warn("Branch contains commits not made by rcmt")
111+
ctx.repo.create_pr_comment_with_identifier(
112+
body=TEMPLATE_BRANCH_MODIFIED.render(
113+
checksums=e.checksums, default_branch=repo.base_branch
114+
),
115+
identifier="branch-modified",
116+
pr=pr_identifier,
117+
)
118+
return RunResult.NO_CHANGES
80119
except GitCommandError as e:
81120
# Catch any error raised by the git client, delete the repository and
82121
# initialize it again
@@ -86,7 +125,9 @@ def execute(
86125
)
87126
checkout_dir = self.git.checkout_dir(repo)
88127
shutil.rmtree(checkout_dir)
89-
work_dir, has_conflict = self.git.prepare(repo)
128+
work_dir, has_conflict = self.git.prepare(
129+
force_rebase=force_rebase, repo=repo
130+
)
90131

91132
apply_actions(ctx=ctx, task_=matcher, work_dir=work_dir)
92133
has_local_changes = self.git.has_changes_local(work_dir)
@@ -202,6 +243,14 @@ def execute(
202243

203244
return RunResult.NO_CHANGES
204245

246+
@staticmethod
247+
def _has_rebase_checked(pr: Any, repo: source.Repository) -> bool:
248+
if pr is None:
249+
return False
250+
251+
desc = repo.get_pr_body(pr)
252+
return "[x] If you want to rebase this PR" in desc
253+
205254

206255
def apply_actions(
207256
ctx: context.Context,

rcmt/source/github.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,13 @@
1414
import structlog
1515

1616
from ..log import SECRET_MASKER
17-
from .source import Base, PullRequest, Repository, add_credentials_to_url
17+
from .source import (
18+
Base,
19+
PullRequest,
20+
PullRequestComment,
21+
Repository,
22+
add_credentials_to_url,
23+
)
1824

1925
log = structlog.get_logger(source="github")
2026

@@ -57,6 +63,9 @@ def close_pull_request(
5763
pr.create_issue_comment(body=message)
5864
pr.edit(state="closed")
5965

66+
def create_pr_comment(self, body: str, pr: github.PullRequest.PullRequest) -> None:
67+
pr.create_issue_comment(body=body)
68+
6069
def create_pull_request(self, branch: str, pr: PullRequest):
6170
log.debug(
6271
"Creating pull request", base=self.base_branch, head=branch, repo=str(self)
@@ -76,6 +85,11 @@ def delete_branch(self, identifier: github.PullRequest.PullRequest) -> None:
7685
if self.repo.delete_branch_on_merge is False:
7786
self.repo.get_git_ref(ref=f"heads/{identifier.head.ref}").delete()
7887

88+
def delete_pr_comment(
89+
self, comment: PullRequestComment, pr: github.PullRequest.PullRequest
90+
) -> None:
91+
pr.get_issue_comment(comment.id).delete()
92+
7993
def find_pull_request(self, branch: str) -> Union[Any, None]:
8094
log.debug("Listing pull requests", repo=str(self))
8195
for pr in self.repo.get_pulls(state="all"):
@@ -98,6 +112,9 @@ def get_file(self, path: str) -> TextIO:
98112

99113
return io.StringIO(file.decoded_content.decode("utf-8"))
100114

115+
def get_pr_body(self, pr: github.PullRequest.PullRequest) -> str:
116+
return pr.body
117+
101118
def has_file(self, path: str) -> bool:
102119
try:
103120
tree = self.repo.get_git_tree(self.base_branch, True)
@@ -136,6 +153,15 @@ def is_pr_merged(self, pr: github.PullRequest.PullRequest) -> bool:
136153
def is_pr_open(self, pr: github.PullRequest.PullRequest) -> bool:
137154
return pr.state == "open"
138155

156+
def list_pr_comments(
157+
self, pr: github.PullRequest.PullRequest
158+
) -> Iterator[PullRequestComment]:
159+
if pr is None:
160+
return []
161+
162+
for issue in pr.get_issue_comments():
163+
yield PullRequestComment(body=issue.body, id=issue.id)
164+
139165
def merge_pull_request(self, pr: github.PullRequest.PullRequest) -> None:
140166
log.debug("Merging pull request", repo=str(self))
141167
pr.merge(commit_title="Auto-merge by rcmt")

rcmt/source/gitlab.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,13 @@
1818
from gitlab.v4.objects.merge_requests import ProjectMergeRequest as GitlabMergeRequest
1919

2020
from ..log import SECRET_MASKER
21-
from .source import Base, PullRequest, Repository, add_credentials_to_url
21+
from .source import (
22+
Base,
23+
PullRequest,
24+
PullRequestComment,
25+
Repository,
26+
add_credentials_to_url,
27+
)
2228

2329
log: structlog.stdlib.BoundLogger = structlog.get_logger(source="gitlab")
2430

@@ -48,6 +54,9 @@ def close_pull_request(self, message: str, pr: GitlabMergeRequest) -> None:
4854
pr.state_event = "close"
4955
pr.save()
5056

57+
def create_pr_comment(self, body: str, pr: GitlabMergeRequest) -> None:
58+
pr.notes.create({"body": body})
59+
5160
def create_pull_request(self, branch: str, pr: PullRequest) -> None:
5261
log.debug(
5362
"Creating merge request", base=self.base_branch, head=branch, repo=str(self)
@@ -67,6 +76,11 @@ def delete_branch(self, identifier: GitlabMergeRequest) -> None:
6776
if identifier.should_remove_source_branch is not True:
6877
self._project.branches.get(id=identifier.source_branch, lazy=True).delete()
6978

79+
def delete_pr_comment(
80+
self, comment: PullRequestComment, pr: GitlabMergeRequest
81+
) -> None:
82+
pr.notes.delete(comment.id)
83+
7084
def find_pull_request(self, branch: str) -> Union[Any, None]:
7185
log.debug("Listing merge requests", repo=str(self))
7286
mrs = self._project.mergerequests.list(
@@ -98,6 +112,9 @@ def get_file(self, path: str) -> TextIO:
98112
except GitlabGetError:
99113
raise FileNotFoundError("file does not exist in repository")
100114

115+
def get_pr_body(self, mr: GitlabMergeRequest) -> str:
116+
return mr.description
117+
101118
def has_file(self, path: str) -> bool:
102119
directory = os.path.dirname(path)
103120
try:
@@ -156,6 +173,13 @@ def is_pr_merged(self, mr: GitlabMergeRequest) -> bool:
156173
def is_pr_open(self, mr: GitlabMergeRequest) -> bool:
157174
return mr.state == "opened"
158175

176+
def list_pr_comments(self, mr: GitlabMergeRequest) -> Iterator[PullRequestComment]:
177+
if mr is None:
178+
return []
179+
180+
for note in mr.notes.list(iterator=True):
181+
yield PullRequestComment(body=note.body, id=note.id)
182+
159183
def merge_pull_request(self, identifier: GitlabMergeRequest):
160184
log.debug("Merging merge request", repo=str(self), id=identifier.get_id())
161185
identifier.merge()

rcmt/source/source.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import datetime
66
import urllib.parse
7+
from dataclasses import dataclass
78
from typing import Any, Generator, Iterator, Optional, TextIO, Union
89

910
import humanize
@@ -100,6 +101,8 @@ def body(self) -> str:
100101
else:
101102
body += "**Ignore:** This PR will be recreated if closed. \n"
102103

104+
body += "\n---\n- [ ] If you want to rebase this PR, check this box"
105+
103106
body += """
104107
---
105108
@@ -116,6 +119,12 @@ def title(self) -> str:
116119
return f"{self.title_prefix} {self.title_body} {self.title_suffix}".strip()
117120

118121

122+
@dataclass
123+
class PullRequestComment:
124+
body: str
125+
id: Any
126+
127+
119128
class Repository:
120129
"""
121130
Repository provides all methods needed to interact with a single repository of a
@@ -169,6 +178,25 @@ def close_pull_request(self, message: str, pr: Any) -> None:
169178
"class does not implement Repository.close_pull_request()"
170179
)
171180

181+
def create_pr_comment(self, body: str, pr: Any) -> None:
182+
raise NotImplementedError(
183+
"class does not implement Repository.create_pr_comment()"
184+
)
185+
186+
def create_pr_comment_with_identifier(
187+
self, body: str, identifier: str, pr: Any
188+
) -> None:
189+
if identifier == "":
190+
raise RuntimeError("identifier cannot be empty")
191+
192+
prefix = f"<!-- rcmt::{identifier} -->"
193+
for comment in self.list_pr_comments(pr):
194+
if comment.body.startswith(prefix):
195+
return None
196+
197+
body = f"{prefix}\n{body}"
198+
self.create_pr_comment(body=body, pr=pr)
199+
172200
def create_pull_request(self, branch: str, pr: PullRequest) -> None:
173201
"""
174202
Creates a pull request for the given branch.
@@ -181,6 +209,21 @@ def create_pull_request(self, branch: str, pr: PullRequest) -> None:
181209
"class does not implement Repository.create_pull_request()"
182210
)
183211

212+
def delete_pr_comment(self, comment: PullRequestComment, pr: Any) -> None:
213+
raise NotImplementedError(
214+
"class does not implement Repository.delete_comment()"
215+
)
216+
217+
def delete_pr_comment_with_identifier(self, identifier: str, pr: Any) -> None:
218+
if identifier == "":
219+
raise RuntimeError("identifier cannot be empty")
220+
221+
prefix = f"<!-- rcmt::{identifier} -->"
222+
for comment in self.list_pr_comments(pr):
223+
if comment.body.startswith(prefix):
224+
self.delete_pr_comment(comment=comment, pr=pr)
225+
return
226+
184227
def delete_branch(self, identifier: Any) -> None:
185228
raise NotImplementedError(
186229
"class does not implement Repository.delete_pull_request()"
@@ -210,6 +253,11 @@ def full_name(self) -> str:
210253
def get_file(self, path: str) -> TextIO:
211254
raise NotImplementedError("class does not implement Repository.has_file()")
212255

256+
def get_pr_body(self, pr_identifier: Any) -> str:
257+
raise NotImplementedError(
258+
"class does not implement Repository.get_pr_description()"
259+
)
260+
213261
def has_file(self, path: str) -> bool:
214262
"""
215263
Checks if a file exists in a repository.
@@ -272,6 +320,11 @@ def is_pr_open(self, identifier: Any) -> bool:
272320
"""
273321
raise NotImplementedError("class does not implement Repository.is_pr_open()")
274322

323+
def list_pr_comments(self, pr: Any) -> Iterator[PullRequestComment]:
324+
raise NotImplementedError(
325+
"class does not implement Repository.list_pr_comments()"
326+
)
327+
275328
def merge_pull_request(self, identifier: Any) -> None:
276329
"""
277330
Merges a pull request.

0 commit comments

Comments
 (0)