diff options
author | Konstantin Ryabitsev <konstantin@linuxfoundation.org> | 2021-12-14 16:01:42 -0500 |
---|---|---|
committer | Konstantin Ryabitsev <konstantin@linuxfoundation.org> | 2021-12-14 16:01:42 -0500 |
commit | db381d559031ce8ff3899ac79d62dbace4db902f (patch) | |
tree | 07e9a04f646bd296cc4fb9cc0e4d4c36ba3d925e | |
parent | a48d14b2cd13540d0ad1a65e504b7d5ac19c15be (diff) | |
download | korg-helpers-db381d559031ce8ff3899ac79d62dbace4db902f.tar.gz |
patchwork-bot: add proper typing hints
It helps to catch easy bugs when writing in PyCharm and other IDEs.
Signed-off-by: Konstantin Ryabitsev <konstantin@linuxfoundation.org>
-rwxr-xr-x | git-patchwork-bot.py | 132 |
1 files changed, 75 insertions, 57 deletions
diff --git a/git-patchwork-bot.py b/git-patchwork-bot.py index a9bc771..6adef71 100755 --- a/git-patchwork-bot.py +++ b/git-patchwork-bot.py @@ -39,6 +39,7 @@ from requests.adapters import HTTPAdapter from requests.packages.urllib3.util.retry import Retry from string import Template +from typing import Optional, Tuple, Union, Dict, List, Set # Send all email 8-bit, this is not 1999 from email import charset @@ -66,7 +67,15 @@ logger = logging.getLogger('gitpwcron') class Restmaker: - def __init__(self, server): + server: str + url: str + series_url: str + patches_url: str + projects_url: str + session: requests.Session + _patches: Dict[int, Optional[dict]] + + def __init__(self, server: str) -> None: self.server = server self.url = '/'.join((server.rstrip('/'), 'api', REST_API_VERSION)) @@ -91,7 +100,7 @@ class Restmaker: headers['Authorization'] = f'Token {apitoken}' self.session.headers.update(headers) - def get_unpaginated(self, url, params): + def get_unpaginated(self, url: str, params: list) -> List[dict]: # Caller should catch RequestException page = 0 results = list() @@ -113,7 +122,7 @@ class Restmaker: return results - def get_cover(self, cover_id): + def get_cover(self, cover_id: int) -> dict: try: logger.debug('Grabbing cover %d', cover_id) url = '/'.join((self.covers_url, str(cover_id), '')) @@ -123,9 +132,9 @@ class Restmaker: return rsp.json() except requests.exceptions.RequestException as ex: logger.info('REST error: %s', ex) - return None + raise KeyError('Not able to get cover %s', cover_id) - def get_patch(self, patch_id): + def get_patch(self, patch_id: int) -> dict: if patch_id not in self._patches: try: logger.debug('Grabbing patch %d', patch_id) @@ -137,23 +146,23 @@ class Restmaker: except requests.exceptions.RequestException as ex: logger.info('REST error: %s', ex) self._patches[patch_id] = None + raise KeyError('Not able to get patch_id %s', patch_id) return self._patches[patch_id] - def get_series(self, series_id): + def get_series(self, series_id: int) -> dict: try: logger.debug('Grabbing series %d', series_id) url = '/'.join((self.series_url, str(series_id), '')) logger.debug('url=%s', url) rsp = self.session.get(url, stream=False) rsp.raise_for_status() + return rsp.json() except requests.exceptions.RequestException as ex: logger.info('REST error: %s', ex) - return None - - return rsp.json() + raise KeyError('Not able to get series %s', series_id) - def get_patches_list(self, params, unpaginated=True): + def get_patches_list(self, params: list, unpaginated: bool = True) -> List[dict]: try: if unpaginated: return self.get_unpaginated(self.patches_url, params) @@ -163,10 +172,9 @@ class Restmaker: return rsp.json() except requests.exceptions.RequestException as ex: logger.info('REST error: %s', ex) + return list() - return None - - def get_series_list(self, params, unpaginated=True): + def get_series_list(self, params: list, unpaginated: bool = True) -> List[dict]: try: if unpaginated: return self.get_unpaginated(self.series_url, params) @@ -176,18 +184,17 @@ class Restmaker: return rsp.json() except requests.exceptions.RequestException as ex: logger.info('REST error: %s', ex) + return list() - return None - - def get_projects_list(self, params): + def get_projects_list(self, params: list) -> list: try: return self.get_unpaginated(self.projects_url, params) except requests.exceptions.RequestException as ex: logger.info('REST error: %s', ex) + return list() - return None - - def update_patch(self, patch_id, state=None, archived=False, commit_ref=None): + def update_patch(self, patch_id: int, state: Optional[str] = None, archived: bool = False, + commit_ref: Optional[str] = None) -> list: # Clear it out of the cache if patch_id in self._patches: del self._patches[patch_id] @@ -211,12 +218,12 @@ class Restmaker: rsp.raise_for_status() except requests.exceptions.RequestException as ex: logger.info('REST error: %s', ex) - return None + raise RuntimeError('Unable to update patch %s', patch_id) return rsp.json() -def get_patchwork_patches_by_project_hash(rm, project, pwhash): +def get_patchwork_patches_by_project_hash(rm: Restmaker, project: int, pwhash: str) -> List[int]: logger.debug('Looking up %s', pwhash) params = [ ('project', project), @@ -226,12 +233,12 @@ def get_patchwork_patches_by_project_hash(rm, project, pwhash): patches = rm.get_patches_list(params) if not patches: logger.debug('No match for hash=%s', pwhash) - return None + return list() return [patch['id'] for patch in patches] -def get_patchwork_pull_requests_by_project(rm, project, fromstate): +def get_patchwork_pull_requests_by_project(rm: Restmaker, project: int, fromstate: List[str]) -> Set[Tuple]: params = [ ('project', project), ('archived', 'false'), @@ -256,16 +263,15 @@ def get_patchwork_pull_requests_by_project(rm, project, fromstate): pull_refname = 'master' prs.add((pull_host, pull_refname, patch_id)) - return prs -def project_by_name(pname): +def project_by_name(pname: str) -> Tuple: global _project_cache global _server_cache if not pname: - return None + raise KeyError('Must specify project name') if pname not in _project_cache: # Find patchwork definition containing this project @@ -302,27 +308,27 @@ def project_by_name(pname): break if not found: logger.info('Could not find project matching %s on server %s', pname, server) - return None + raise KeyError(f'No match for project {pname} on server {server}') return _project_cache[pname] -def db_save_meta(c): +def db_save_meta(c: sqlite3.Cursor) -> None: c.execute('DELETE FROM meta') c.execute('''INSERT INTO meta VALUES(?)''', (DB_VERSION,)) -def db_save_repo_heads(c, heads): +def db_save_repo_heads(c: sqlite3.Cursor, heads: list) -> None: c.execute('DELETE FROM heads') for refname, commit_id in heads: c.execute('''INSERT INTO heads VALUES(?,?)''', (refname, commit_id)) -def db_get_repo_heads(c): +def db_get_repo_heads(c: sqlite3.Cursor) -> List[Tuple]: return c.execute('SELECT refname, commit_id FROM heads').fetchall() -def db_init_common_sqlite_db(c): +def db_init_common_sqlite_db(c: sqlite3.Cursor) -> None: c.execute(''' CREATE TABLE meta ( version INTEGER @@ -330,7 +336,7 @@ def db_init_common_sqlite_db(c): db_save_meta(c) -def db_init_cache_sqlite_db(c): +def db_init_cache_sqlite_db(c: sqlite3.Cursor) -> None: logger.info('Initializing new sqlite3 db with metadata version %s', DB_VERSION) db_init_common_sqlite_db(c) c.execute(''' @@ -343,7 +349,7 @@ def db_init_cache_sqlite_db(c): c.execute('''CREATE UNIQUE INDEX idx_rev ON revs(rev)''') -def db_init_pw_sqlite_db(c): +def db_init_pw_sqlite_db(c: sqlite3.Cursor) -> None: logger.info('Initializing new sqlite3 db with metadata version %s', DB_VERSION) db_init_common_sqlite_db(c) c.execute(''' @@ -353,7 +359,7 @@ def db_init_pw_sqlite_db(c): )''') -def git_get_command_lines(gitdir, args): +def git_get_command_lines(gitdir: str, args: List[str]) -> list: out = git_run_command(gitdir, args) lines = list() if out: @@ -365,7 +371,7 @@ def git_get_command_lines(gitdir, args): return lines -def git_run_command(gitdir, args, stdin=None): +def git_run_command(gitdir: str, args: List[str], stdin: Optional[str] = None) -> str: args = ['git', '--no-pager', '--git-dir', gitdir] + args logger.debug('Running %s' % ' '.join(args)) @@ -383,7 +389,7 @@ def git_run_command(gitdir, args, stdin=None): return output -def git_get_repo_heads(gitdir, branch, ancestry=None): +def git_get_repo_heads(gitdir: str, branch: str, ancestry: Optional[str] = None) -> List[Tuple[str, str]]: refs = list() lines = git_get_command_lines(gitdir, ['show-ref', branch]) if ancestry is None: @@ -397,7 +403,8 @@ def git_get_repo_heads(gitdir, branch, ancestry=None): return refs -def git_get_new_revs(gitdir, db_heads, git_heads, committers, merges=False): +def git_get_new_revs(gitdir: str, db_heads: List[Tuple[str, str]], git_heads: List[Tuple[str, str]], + committers: List[str], merges: bool = False) -> Dict[str, list]: newrevs = dict() if committers: logger.debug('filtering by committers=%s', committers) @@ -453,12 +460,12 @@ def git_get_new_revs(gitdir, db_heads, git_heads, committers, merges=False): return newrevs -def git_get_rev_diff(gitdir, rev): +def git_get_rev_diff(gitdir: str, rev: str) -> str: args = ['diff', '%s~..%s' % (rev, rev)] return git_run_command(gitdir, args) -def git_get_patch_id(diff): +def git_get_patch_id(diff: str) -> Optional[str]: args = ['patch-id', '--stable'] out = git_run_command('', args, stdin=diff) logger.debug('out=%s', out) @@ -467,7 +474,7 @@ def git_get_patch_id(diff): return out.split()[0] -def get_patchwork_hash(diff): +def get_patchwork_hash(diff: str) -> str: """Generate a hash from a diff. Lifted verbatim from patchwork.""" # normalise spaces @@ -515,13 +522,14 @@ def get_patchwork_hash(diff): return hashed.hexdigest() -def listify(obj): +def listify(obj: Union[str, list, None]) -> list: if isinstance(obj, list): return list(obj) return [obj] -def send_summary(serieslist, committers, to_state, refname, pname, rs, hs): +def send_summary(serieslist: List[dict], committers: Dict[int, str], to_state: str, refname: str, pname: str, + rs: Dict[str, str], hs: Dict[str, str]) -> str: logger.info('Preparing summary') # we send summaries by project, so the project name is going to be all the same @@ -610,7 +618,7 @@ def send_summary(serieslist, committers, to_state, refname, pname, rs, hs): return str(msg['Message-Id']) -def get_tweaks(pconfig, hconfig): +def get_tweaks(pconfig: Dict[str, str], hconfig: Dict[str, str]) -> Dict[str, str]: fields = ['from', 'summaryto', 'onlyto', 'neverto', 'onlyifcc', 'neverifcc', 'alwayscc', 'alwaysbcc', 'cclist', 'ccall'] bubbled = dict() @@ -623,7 +631,9 @@ def get_tweaks(pconfig, hconfig): return bubbled -def notify_submitters(serieslist, committers, refname, revs, pname, rs, hs): +def notify_submitters(serieslist: List[dict], committers: Dict[int, str], refname: str, + revs: Dict[int, str], pname: str, rs: Dict[str, Union[str, list, dict]], + hs: Dict[str, Union[str, list, dict]]) -> None: logger.info('Sending submitter notifications') project, rm, pconfig = project_by_name(pname) @@ -634,18 +644,26 @@ def notify_submitters(serieslist, committers, refname, revs, pname, rs, hs): # else the reference is the msgid of the first patch patches = sdata.get('patches') is_pull_request = False + content = headers = reference = None if sdata.get('cover_letter'): reference = sdata.get('cover_letter').get('msgid') - fullcover = rm.get_cover(sdata.get('cover_letter').get('id')) - headers = {k.lower(): v for k, v in fullcover.get('headers').items()} - content = fullcover.get('content') - else: + try: + fullcover = rm.get_cover(sdata.get('cover_letter').get('id')) + headers = {k.lower(): v for k, v in fullcover.get('headers').items()} + content = fullcover.get('content') + except KeyError: + logger.debug('Unable to get cover letter, will try first patch') + if not reference: reference = patches[0].get('msgid') - fullpatch = rm.get_patch(patches[0].get('id')) - headers = {k.lower(): v for k, v in fullpatch.get('headers').items()} - content = fullpatch.get('content') - if fullpatch.get('pull_url'): - is_pull_request = True + try: + fullpatch = rm.get_patch(patches[0].get('id')) + headers = {k.lower(): v for k, v in fullpatch.get('headers').items()} + content = fullpatch.get('content') + if fullpatch.get('pull_url'): + is_pull_request = True + except KeyError: + logger.debug('Unable to get first patch reference, bailing on %s', sdata.get('id')) + continue submitter = sdata.get('submitter') project = sdata.get('project') @@ -784,7 +802,7 @@ def notify_submitters(serieslist, committers, refname, revs, pname, rs, hs): logger.info('------------------------------') -def housekeeping(pname): +def housekeeping(pname: str) -> None: project, rm, pconfig = project_by_name(pname) if 'housekeeping' not in pconfig: return @@ -1030,7 +1048,7 @@ def housekeeping(pname): logger.info('------------------------------') -def pwrun(repo, rsettings): +def pwrun(repo: str, rsettings: Dict[str, Union[str, list, dict]]) -> None: git_heads = git_get_repo_heads(repo, branch=rsettings.get('branch', '--heads')) if not git_heads: logger.info('Could not get the latest ref in %s', repo) @@ -1288,7 +1306,7 @@ def pwrun(repo, rsettings): dbconn.commit() -def check_repos(): +def check_repos() -> None: # Use a global lock to make sure only a single process is running try: lockfh = open(os.path.join(CACHEDIR, 'patchwork-bot.global.lock'), 'w') @@ -1310,7 +1328,7 @@ def check_repos(): pwrun(fullpath, settings) -def pwhash_differ(): +def pwhash_differ() -> None: diff = sys.stdin.read() pwhash = get_patchwork_hash(diff) print(pwhash) |