2828import os
2929from shutil import rmtree
3030from pathlib import Path
31+ from typing import Tuple , Union
3132
3233from lib .model .smartplugin import SmartPlugin
3334from lib .shyaml import yaml_load
3738from github import Auth
3839from github import Github
3940from git import Repo
41+ from git .exc import GitCommandError
4042
4143
4244#
4648class GitHubHelper (object ):
4749 """ Helper class for handling the GitHub API """
4850
49- def loggerr (self , msg ):
51+ def loggerr (self , msg : str ):
5052 """ log error message and raise GPError to signal WebIf """
5153
5254 # TODO: this need to be reworked if WebIf errors should be displayed in German or translated
5355 self .logger .error (msg )
5456 raise GPError (msg )
5557
56- def __init__ (self , dt , logger , repo = 'plugins' , apikey = '' , auth = None , ** kwargs ):
58+ def __init__ (self , dt , logger , repo : str = 'plugins' , apikey : str = '' , auth : Union [ Auth . Token , None ] = None , ** kwargs ):
5759 self .dt = dt
5860 self .logger = logger
5961 self .apikey = apikey
@@ -133,7 +135,7 @@ def get_rate_limit(self):
133135
134136 return [allow , remain , backoff ]
135137
136- def get_repo (self , user , repo ):
138+ def get_repo (self , user : str , repo : str ):
137139 if not self ._github :
138140 self .login ()
139141
@@ -149,7 +151,7 @@ def set_repo(self) -> bool:
149151 self .git_repo = self .get_repo ('smarthomeNG' , self .repo )
150152 return True
151153
152- def get_pulls (self , fetch = False ) -> bool :
154+ def get_pulls (self , fetch : bool = False ) -> bool :
153155 if not self ._github :
154156 self .login ()
155157
@@ -180,7 +182,7 @@ def get_pulls(self, fetch=False) -> bool:
180182
181183 return True
182184
183- def get_forks (self , fetch = False ) -> bool :
185+ def get_forks (self , fetch : bool = False ) -> bool :
184186 if not self ._github :
185187 self .login ()
186188
@@ -202,7 +204,7 @@ def get_forks(self, fetch=False) -> bool:
202204
203205 return True
204206
205- def get_branches_from (self , fork = None , owner = '' , fetch = False ) -> dict :
207+ def get_branches_from (self , fork : Union [ Repo , None ] = None , owner : str = '' , fetch : bool = False ) -> dict :
206208
207209 if fork is None and owner :
208210 try :
@@ -233,7 +235,7 @@ def get_branches_from(self, fork=None, owner='', fetch=False) -> dict:
233235 self .forks [fork .owner .login ]['branches' ] = b_list
234236 return b_list
235237
236- def get_plugins_from (self , fork = None , owner = '' , branch = '' , fetch = False ) -> list :
238+ def get_plugins_from (self , fork : Union [ Repo , None ] = None , owner : str = '' , branch : str = '' , fetch : bool = False ) -> list :
237239
238240 if not branch :
239241 return []
@@ -287,7 +289,7 @@ class GithubPlugin(SmartPlugin):
287289 PLUGIN_VERSION = '1.0.1'
288290 REPO_DIR = 'priv_repos'
289291
290- def loggerr (self , msg ):
292+ def loggerr (self , msg : str ):
291293 """ log error message and raise GPError to signal WebIf """
292294 self .logger .error (msg )
293295 raise GPError (msg )
@@ -313,7 +315,9 @@ def __init__(self, sh):
313315 # 'link': os.path.join('plugins', f'priv_{plugin}'), # absoluter Pfad-/Dateiname des Plugin-Symlinks
314316 # 'rel_link_path': os.path.join(wt_path, plugin), # Ziel der Plugin-Symlinks: relativer Pfad des Ziel-Pluginordners "unterhalb" von plugins/
315317 # 'repo': repo, # git.Repo(path)
316- # 'clean': bool # repo is clean and synced?
318+ # 'clean': bool, # repo is clean and synced?
319+ # 'lcommit': str, # local commit head
320+ # 'rcommit': str # remote commit head
317321 # },
318322 # '<id2>': {...}
319323 # }
@@ -341,7 +345,7 @@ def __init__(self, sh):
341345 # methods for handling local repos
342346 #
343347
344- def read_repos_from_dir (self , exc = False ):
348+ def read_repos_from_dir (self , exc : bool = False ):
345349 # clear stored repos
346350 self .repos = {}
347351
@@ -400,19 +404,26 @@ def read_repos_from_dir(self, exc=False):
400404 'link' : str (item ),
401405 'rel_link_path' : str (target ),
402406 'repo' : repo ,
407+ 'lcommit' : '' ,
408+ 'rcommit' : ''
403409 }
404410 self .repos [name ]['clean' ] = self .is_repo_clean (name , exc )
411+
412+ # fill head commits for local and remote branches
413+ if not self .repos [name ]['lcommit' ] or not self .repos [name ]['rcommit' ]:
414+ self .get_head_commits (name )
415+
405416 self .logger .info (f'added plugin { plugin } with name { name } in { item } ' )
406417
407- def check_for_repo_name (self , name ) -> bool :
418+ def check_for_repo_name (self , name : str ) -> bool :
408419 """ check if name exists in repos or link exists """
409420 if name in self .repos or os .path .exists (os .path .join (self .plg_path , 'priv_' + name )):
410421 self .loggerr (f'name { name } already taken, delete old plugin first or choose a different name.' )
411422 return False
412423
413424 return True
414425
415- def create_repo (self , name , owner , plugin , branch = None , rename = False ) -> bool :
426+ def create_repo (self , name : str , owner : str , plugin : str , branch : str = '' , rename : bool = False ) -> bool :
416427 """ create repo from given parameters """
417428
418429 if any (x in name for x in ['/' , '..' ]) or name == self .REPO_DIR :
@@ -421,14 +432,14 @@ def create_repo(self, name, owner, plugin, branch=None, rename=False) -> bool:
421432
422433 if not self .supermode :
423434 if not name .startswith ('priv_' ):
424- self .loggerr (f'Invalid name, must start with "priv_"' )
435+ self .loggerr (f'Name { name } invalid , must start with "priv_"' )
425436 return False
426437
427438 if not rename :
428439 try :
429440 self .check_for_repo_name (name )
430441 except Exception as e :
431- self .loggerr (e )
442+ self .loggerr (e . __repr__ () )
432443 return False
433444
434445 if not owner or not plugin :
@@ -565,10 +576,11 @@ def create_repo(self, name, owner, plugin, branch=None, rename=False) -> bool:
565576 return False
566577
567578 self .repos [name ] = repo
579+ self .get_head_commits (name )
568580
569581 return True
570582
571- def _move_old_link (self , name ) -> bool :
583+ def _move_old_link (self , name : str ) -> bool :
572584 if not self .supermode and not os .path .basename (name ).startswith ('priv_' ):
573585 self .loggerr (f'unable to move plugin with illegal name { name } ' )
574586 return False
@@ -609,7 +621,7 @@ def _move_old_link(self, name) -> bool:
609621 self .loggerr (f'error renaming old plugin: { e } ' )
610622 return False
611623
612- def _rmtree (self , path ):
624+ def _rmtree (self , path : str ):
613625 """ remove path tree, also try to remove .DS_Store if present """
614626 try :
615627 rmtree (path )
@@ -627,7 +639,7 @@ def _rmtree(self, path):
627639 # Try again, but finally give up if error persists
628640 rmtree (path )
629641
630- def remove_plugin (self , name ) -> bool :
642+ def remove_plugin (self , name : str ) -> bool :
631643 """ remove plugin link, worktree and if not longer needed, local repo """
632644 if name not in self .repos :
633645 self .loggerr (f'plugin entry { name } not found.' )
@@ -709,46 +721,55 @@ def remove_plugin(self, name) -> bool:
709721 # github API methods
710722 #
711723
712- def is_repo_clean (self , name : str , exc = False ) -> bool :
713- """ checks if worktree is clean and local and remote branches are in sync """
714- if name not in self .repos :
715- self .loggerr (f'repo { name } not found' )
716- return False
724+ def get_head_commits (self , name : str , exc : bool = False ) -> bool :
725+ """ tries to get current local and remote head commits """
717726
718727 entry = self .repos [name ]
719728 local = entry ['repo' ]
720729
721- # abort if worktree isn't clean
722- if local .is_dirty () or local .untracked_files != []:
723- self .logger .debug (f'repo { name } : dirty: { local .is_dirty ()} , untracked files: { local .untracked_files } ' )
724- self .repos [name ]['clean' ] = False
725- return False
726-
727730 # get remote and local branch heads
728731 try :
729732 remote = self .gh .get_repo (entry ['owner' ], entry ['gh_repo' ])
730733 r_branch = remote .get_branch (branch = entry ['branch' ])
731- r_head = r_branch .commit .sha
732-
733- l_head = local . heads [ entry [ 'branch' ]]. commit . hexsha
734+ entry [ 'rcommit' ] = r_branch .commit .sha
735+ entry [ 'lcommit' ] = local . heads [ entry [ 'branch' ]]. commit . hexsha
736+ return True
734737 except AttributeError :
735738 if exc :
736739 f = self .loggerr
737740 else :
738741 f = self .logger .warning
739- f (f'error while checking sync status for { name } . Rate limit active?' )
742+ f (f'error while getting commits for { name } . Rate limit active?' )
740743 return False
741744 except Exception as e :
742- self .loggerr (f'error while checking sync status for { name } : { e } ' )
745+ self .loggerr (f'error while getting commits for { name } : { e } ' )
743746 return False
744747
745- clean = l_head == r_head
748+ def is_repo_clean (self , name : str , exc : bool = False ) -> bool :
749+ """ checks if worktree is clean and local and remote branches are in sync """
750+ if name not in self .repos :
751+ self .loggerr (f'repo { name } not found' )
752+ return False
753+
754+ entry = self .repos [name ]
755+ local = entry ['repo' ]
756+
757+ # abort if worktree isn't clean
758+ if local .is_dirty () or local .untracked_files != []:
759+ self .logger .debug (f'repo { name } : dirty: { local .is_dirty ()} , untracked files: { local .untracked_files } ' )
760+ self .repos [name ]['clean' ] = False
761+ return False
762+
763+ if not self .get_head_commits (name , exc ):
764+ return False
765+
766+ clean = entry ['lcommit' ] == entry ['rcommit' ]
746767 if not clean :
747768 try :
748- _ = list (repo .iter_commits (r_head ))
769+ _ = list (local .iter_commits (entry [ 'rcommit' ] ))
749770 # as clean is excluded, we must be ahead. Possibly out changes are not saved, so stay as "not clean""
750771 pass
751- except git . exc . GitCommandError :
772+ except GitCommandError :
752773 # commit not in local, we are not clean and not ahead, so we are behind
753774 # being beind with clean worktree means nothing gets lost or overwritten. Allow operations
754775 clean = True
@@ -784,9 +805,11 @@ def pull_repo(self, name: str) -> bool:
784805
785806 try :
786807 org .pull ()
808+ self .get_head_commits (name )
787809 return True
788810 except Exception as e :
789811 self .loggerr (f'error while pulling: { e } ' )
812+ return False
790813
791814 def setup_github (self ) -> bool :
792815 """ login to github and set repo """
@@ -798,21 +821,21 @@ def setup_github(self) -> bool:
798821
799822 return self .gh .set_repo ()
800823
801- def fetch_github_forks (self , fetch = False ) -> bool :
824+ def fetch_github_forks (self , fetch : bool = False ) -> bool :
802825 """ fetch forks from github API """
803826 if self .gh :
804827 return self .gh .get_forks (fetch = fetch )
805828 else :
806829 return False
807830
808- def fetch_github_pulls (self , fetch = False ) -> bool :
831+ def fetch_github_pulls (self , fetch : bool = False ) -> bool :
809832 """ fetch PRs from github API """
810833 if self .gh :
811834 return self .gh .get_pulls (fetch = fetch )
812835 else :
813836 return False
814837
815- def fetch_github_branches_from (self , fork = None , owner = '' , fetch = False ) -> dict :
838+ def fetch_github_branches_from (self , fork : Union [ Repo , None ] = None , owner : str = '' , fetch : bool = False ) -> dict :
816839 """
817840 fetch branches for given fork from github API
818841
@@ -821,11 +844,11 @@ def fetch_github_branches_from(self, fork=None, owner='', fetch=False) -> dict:
821844 """
822845 return self .gh .get_branches_from (fork = fork , owner = owner , fetch = fetch )
823846
824- def fetch_github_plugins_from (self , fork = None , owner = '' , branch = '' , fetch = False ) -> list :
847+ def fetch_github_plugins_from (self , fork : Union [ Repo , None ] = None , owner : str = '' , branch : str = '' , fetch : bool = False ) -> list :
825848 """ fetch plugin names for selected fork/branch """
826849 return self .gh .get_plugins_from (fork = fork , owner = owner , branch = branch , fetch = fetch )
827850
828- def get_github_forks (self , owner = None ) -> dict :
851+ def get_github_forks (self , owner : str = '' ) -> dict :
829852 """ return forks or single fork for given owner """
830853 if owner :
831854 return self .gh .forks .get (owner , {})
@@ -853,7 +876,7 @@ def get_github_forklist_sorted(self) -> list:
853876
854877 return sforkstop + sforks
855878
856- def get_github_pulls (self , number = None ) -> dict :
879+ def get_github_pulls (self , number : int = 0 ) -> dict :
857880 """ return pulls or single pull for given number """
858881 if number :
859882 return self .gh .pulls .get (number , {})
@@ -865,7 +888,7 @@ def get_github_pulls(self, number=None) -> dict:
865888 #
866889
867890 # unused right now, possibly remove?
868- def create_repo_from_gh (self , number = 0 , owner = '' , branch = None , plugin = '' ) -> bool :
891+ def create_repo_from_gh (self , number : int = 0 , owner : str = '' , branch : Union [ Repo , str , None ] = None , plugin : str = '' ) -> bool :
869892 """
870893 call init/create methods to download new repo and create worktree
871894
@@ -959,7 +982,7 @@ def stop(self):
959982 # helper methods
960983 #
961984
962- def _get_last_3_path_parts (self , path ) :
985+ def _get_last_3_path_parts (self , path : str ) -> Tuple [ str , str , str ] :
963986 """ return last 3 parts of a path """
964987 try :
965988 head , l3 = os .path .split (path )
0 commit comments