1111from django .core .management .commands .makemigrations import Command as BaseCommand
1212from django .db import DEFAULT_DB_ALIAS , connections , router
1313from django .db .migrations .loader import MigrationLoader
14+ from git import InvalidGitRepositoryError , Repo
1415
1516from migration_fixer .utils import (
1617 fix_named_migration ,
1718 fix_numbered_migration ,
1819 no_translations ,
19- run_command ,
2020)
2121
2222
@@ -27,6 +27,11 @@ class Command(BaseCommand):
2727
2828 success_msg = "Successfully fixed migrations."
2929
30+ def __init__ (self , * args , repo = None , ** kwargs ):
31+ super ().__init__ (* args , ** kwargs )
32+ self .cwd = os .getcwd ()
33+ self .repo = repo or Repo .init (self .cwd )
34+
3035 def add_arguments (self , parser ):
3136 parser .add_argument (
3237 "--fix" ,
@@ -39,12 +44,19 @@ def add_arguments(self, parser):
3944 help = "The name of the default branch." ,
4045 default = "main" ,
4146 )
47+ parser .add_argument (
48+ "-f" ,
49+ "--force-update" ,
50+ help = "Force update the default branch." ,
51+ action = "store_true" ,
52+ )
4253 super ().add_arguments (parser )
4354
4455 @no_translations
4556 def handle (self , * app_labels , ** options ):
4657 self .merge = options ["merge" ]
4758 self .fix = options ["fix" ]
59+ self .force_update = options ["force_update" ]
4860 self .default_branch = options ["default_branch" ]
4961
5062 if self .fix :
@@ -53,68 +65,50 @@ def handle(self, *app_labels, **options):
5365 except CommandError as e :
5466 [message ] = e .args
5567 if "Conflicting migrations" in message :
56- (
57- git_setup_has_error ,
58- git_setup_output ,
59- git_setup_error ,
60- ) = run_command ("git status" )
68+ if self .verbosity >= 2 :
69+ self .stdout .write ("Verifying git repository..." )
70+
71+ try :
72+ self .repo .git_dir
73+ except InvalidGitRepositoryError :
74+ is_git_repo = False
75+ else :
76+ is_git_repo = True
6177
62- if not git_setup_has_error :
78+ if not is_git_repo :
6379 raise CommandError (
6480 self .style .ERROR (
65- f"VCS is not yet setup. "
66- "Please run (git init) "
67- f'\n "{ git_setup_output or git_setup_error } "'
81+ f"Git repository is not yet setup. "
82+ "Please run (git init) in "
83+ f'\n "{ self . cwd } "'
6884 )
6985 )
7086
71- (
72- get_current_branch_has_error ,
73- get_current_branch_output ,
74- get_current_branch_error ,
75- ) = run_command ("git branch --show-current" )
87+ if self .verbosity >= 2 :
88+ self .stdout .write ("Retrieving the current branch..." )
7689
77- if not get_current_branch_has_error :
78- raise CommandError (
79- self .style .ERROR (
80- f"Unable to determine the current branch: "
81- f'"{ get_current_branch_output or get_current_branch_error } "'
82- )
83- )
90+ current_branch = self .repo .active_branch .name
8491
85- pull_command = (
86- "git pull"
87- if get_current_branch_output == self .default_branch
88- else (
89- "git fetch origin "
90- f"{ self .default_branch } :{ self .default_branch } "
92+ if self .verbosity >= 2 :
93+ self .stdout .write (
94+ f"Fetching git remote origin changes on: { self .default_branch } "
9195 )
92- )
9396
94- # Pull the last commit
95- git_pull_has_error , git_pull_output , git_pull_error = run_command (
96- pull_command
97- )
97+ if current_branch == self .default_branch :
98+ self .repo .remotes [self .default_branch ].origin .pull ()
99+ else :
100+ for remote in self .repo .remotes :
101+ remote .fetch (self .default_branch , force = self .force_update )
98102
99- if not git_pull_has_error :
100- raise CommandError (
101- self .style .ERROR (
102- f"Error pulling branch ({ self .default_branch } ) changes: "
103- f'"{ git_pull_output or git_pull_error } "'
104- )
103+ if self .verbosity >= 2 :
104+ self .stdout .write (
105+ f"Retrieving the last commit sha on: { self .default_branch } "
105106 )
106107
107- head_sha_has_error , head_sha_output , head_sha_error = run_command (
108- f"git rev-parse { self . default_branch } "
109- )
108+ default_branch_commit = self . repo . commit ( self . default_branch )
109+
110+ current_commit = self . repo . commit ( current_branch )
110111
111- if not head_sha_has_error :
112- raise CommandError (
113- self .style .ERROR (
114- f"Error determining head sha on ({ self .default_branch } ): "
115- f'"{ head_sha_output or head_sha_error } "'
116- )
117- )
118112 # Load the current graph state. Pass in None for the connection so
119113 # the loader doesn't try to resolve replaced migrations from DB.
120114 loader = MigrationLoader (None , ignore_no_migrations = True )
@@ -147,13 +141,7 @@ def handle(self, *app_labels, **options):
147141 # hard if there are any and they don't want to merge
148142 conflicts = loader .detect_conflicts ()
149143
150- app_labels = app_labels or tuple (
151- app_label
152- for app_label in settings .INSTALLED_APPS
153- if app_label in conflicts
154- )
155-
156- for app_label in app_labels :
144+ for app_label in conflicts :
157145 conflict = conflicts .get (app_label )
158146 migration_module , _ = loader .migrations_module (app_label )
159147 migration_absolute_path = os .path .join (
@@ -164,33 +152,34 @@ def handle(self, *app_labels, **options):
164152 )
165153
166154 with migration_path :
167- (
168- get_changed_files_has_error ,
169- get_changed_files_output ,
170- get_changed_files_error ,
171- ) = run_command (
172- f"git diff --diff-filter=ACMUXTR --name-only { self .default_branch } "
173- )
174-
175- if not get_changed_files_has_error :
176- raise CommandError (
177- self .style .ERROR (
178- "Error retrieving changed files on "
179- f"({ self .default_branch } ): "
180- f'"{ get_changed_files_output or get_changed_files_error } "'
181- )
155+ if self .verbosity >= 2 :
156+ self .stdout .write (
157+ "Retrieving changed files between "
158+ f"the current branch and { self .default_branch } "
182159 )
160+
161+ diff_index = default_branch_commit .diff (current_commit )
162+
183163 # Files different on the current branch
184164 changed_files = [
185- fname
186- for fname in get_changed_files_output .split ("\n " )
187- if migration_absolute_path in fname
165+ diff .b_path
166+ for diff in diff_index
167+ if migration_absolute_path
168+ in getattr (diff .a_blob , "abspath" , "" )
169+ or migration_absolute_path
170+ in getattr (diff .b_blob , "abspath" , "" )
188171 ]
172+
189173 # Local migration
190174 local_filenames = [
191175 os .path .splitext (os .path .basename (p ))[0 ]
192176 for p in changed_files
193177 ]
178+ if self .verbosity >= 2 :
179+ self .stdout .write (
180+ f"Retrieving the last migration on: { self .default_branch } "
181+ )
182+
194183 last_remote = [
195184 fname
196185 for fname in conflict
@@ -218,6 +207,11 @@ def handle(self, *app_labels, **options):
218207 and len (seed_split ) > 1
219208 and str (seed_split [0 ]).isdigit ()
220209 ):
210+ if self .verbosity >= 2 :
211+ self .stdout .write (
212+ "Fixing numbered migration..."
213+ )
214+
221215 fix_numbered_migration (
222216 app_label = app_label ,
223217 migration_path = migration_path ,
@@ -226,6 +220,9 @@ def handle(self, *app_labels, **options):
226220 changed_files = changed_files ,
227221 )
228222 else :
223+ if self .verbosity >= 2 :
224+ self .stdout .write ("Fixing named migration..." )
225+
229226 fix_named_migration (
230227 app_label = app_label ,
231228 migration_path = migration_path ,
0 commit comments