aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorKonstantin Ryabitsev <konstantin@linuxfoundation.org>2021-12-14 16:01:42 -0500
committerKonstantin Ryabitsev <konstantin@linuxfoundation.org>2021-12-14 16:01:42 -0500
commitdb381d559031ce8ff3899ac79d62dbace4db902f (patch)
tree07e9a04f646bd296cc4fb9cc0e4d4c36ba3d925e
parenta48d14b2cd13540d0ad1a65e504b7d5ac19c15be (diff)
downloadkorg-helpers-db381d559031ce8ff3899ac79d62dbace4db902f.tar.gz
patchwork-bot: add proper typing hints
It helps to catch easy bugs when writing in PyCharm and other IDEs. Signed-off-by: Konstantin Ryabitsev <konstantin@linuxfoundation.org>
-rwxr-xr-xgit-patchwork-bot.py132
1 files changed, 75 insertions, 57 deletions
diff --git a/git-patchwork-bot.py b/git-patchwork-bot.py
index a9bc771..6adef71 100755
--- a/git-patchwork-bot.py
+++ b/git-patchwork-bot.py
@@ -39,6 +39,7 @@ from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
from string import Template
+from typing import Optional, Tuple, Union, Dict, List, Set
# Send all email 8-bit, this is not 1999
from email import charset
@@ -66,7 +67,15 @@ logger = logging.getLogger('gitpwcron')
class Restmaker:
- def __init__(self, server):
+ server: str
+ url: str
+ series_url: str
+ patches_url: str
+ projects_url: str
+ session: requests.Session
+ _patches: Dict[int, Optional[dict]]
+
+ def __init__(self, server: str) -> None:
self.server = server
self.url = '/'.join((server.rstrip('/'), 'api', REST_API_VERSION))
@@ -91,7 +100,7 @@ class Restmaker:
headers['Authorization'] = f'Token {apitoken}'
self.session.headers.update(headers)
- def get_unpaginated(self, url, params):
+ def get_unpaginated(self, url: str, params: list) -> List[dict]:
# Caller should catch RequestException
page = 0
results = list()
@@ -113,7 +122,7 @@ class Restmaker:
return results
- def get_cover(self, cover_id):
+ def get_cover(self, cover_id: int) -> dict:
try:
logger.debug('Grabbing cover %d', cover_id)
url = '/'.join((self.covers_url, str(cover_id), ''))
@@ -123,9 +132,9 @@ class Restmaker:
return rsp.json()
except requests.exceptions.RequestException as ex:
logger.info('REST error: %s', ex)
- return None
+ raise KeyError('Not able to get cover %s', cover_id)
- def get_patch(self, patch_id):
+ def get_patch(self, patch_id: int) -> dict:
if patch_id not in self._patches:
try:
logger.debug('Grabbing patch %d', patch_id)
@@ -137,23 +146,23 @@ class Restmaker:
except requests.exceptions.RequestException as ex:
logger.info('REST error: %s', ex)
self._patches[patch_id] = None
+ raise KeyError('Not able to get patch_id %s', patch_id)
return self._patches[patch_id]
- def get_series(self, series_id):
+ def get_series(self, series_id: int) -> dict:
try:
logger.debug('Grabbing series %d', series_id)
url = '/'.join((self.series_url, str(series_id), ''))
logger.debug('url=%s', url)
rsp = self.session.get(url, stream=False)
rsp.raise_for_status()
+ return rsp.json()
except requests.exceptions.RequestException as ex:
logger.info('REST error: %s', ex)
- return None
-
- return rsp.json()
+ raise KeyError('Not able to get series %s', series_id)
- def get_patches_list(self, params, unpaginated=True):
+ def get_patches_list(self, params: list, unpaginated: bool = True) -> List[dict]:
try:
if unpaginated:
return self.get_unpaginated(self.patches_url, params)
@@ -163,10 +172,9 @@ class Restmaker:
return rsp.json()
except requests.exceptions.RequestException as ex:
logger.info('REST error: %s', ex)
+ return list()
- return None
-
- def get_series_list(self, params, unpaginated=True):
+ def get_series_list(self, params: list, unpaginated: bool = True) -> List[dict]:
try:
if unpaginated:
return self.get_unpaginated(self.series_url, params)
@@ -176,18 +184,17 @@ class Restmaker:
return rsp.json()
except requests.exceptions.RequestException as ex:
logger.info('REST error: %s', ex)
+ return list()
- return None
-
- def get_projects_list(self, params):
+ def get_projects_list(self, params: list) -> list:
try:
return self.get_unpaginated(self.projects_url, params)
except requests.exceptions.RequestException as ex:
logger.info('REST error: %s', ex)
+ return list()
- return None
-
- def update_patch(self, patch_id, state=None, archived=False, commit_ref=None):
+ def update_patch(self, patch_id: int, state: Optional[str] = None, archived: bool = False,
+ commit_ref: Optional[str] = None) -> list:
# Clear it out of the cache
if patch_id in self._patches:
del self._patches[patch_id]
@@ -211,12 +218,12 @@ class Restmaker:
rsp.raise_for_status()
except requests.exceptions.RequestException as ex:
logger.info('REST error: %s', ex)
- return None
+ raise RuntimeError('Unable to update patch %s', patch_id)
return rsp.json()
-def get_patchwork_patches_by_project_hash(rm, project, pwhash):
+def get_patchwork_patches_by_project_hash(rm: Restmaker, project: int, pwhash: str) -> List[int]:
logger.debug('Looking up %s', pwhash)
params = [
('project', project),
@@ -226,12 +233,12 @@ def get_patchwork_patches_by_project_hash(rm, project, pwhash):
patches = rm.get_patches_list(params)
if not patches:
logger.debug('No match for hash=%s', pwhash)
- return None
+ return list()
return [patch['id'] for patch in patches]
-def get_patchwork_pull_requests_by_project(rm, project, fromstate):
+def get_patchwork_pull_requests_by_project(rm: Restmaker, project: int, fromstate: List[str]) -> Set[Tuple]:
params = [
('project', project),
('archived', 'false'),
@@ -256,16 +263,15 @@ def get_patchwork_pull_requests_by_project(rm, project, fromstate):
pull_refname = 'master'
prs.add((pull_host, pull_refname, patch_id))
-
return prs
-def project_by_name(pname):
+def project_by_name(pname: str) -> Tuple:
global _project_cache
global _server_cache
if not pname:
- return None
+ raise KeyError('Must specify project name')
if pname not in _project_cache:
# Find patchwork definition containing this project
@@ -302,27 +308,27 @@ def project_by_name(pname):
break
if not found:
logger.info('Could not find project matching %s on server %s', pname, server)
- return None
+ raise KeyError(f'No match for project {pname} on server {server}')
return _project_cache[pname]
-def db_save_meta(c):
+def db_save_meta(c: sqlite3.Cursor) -> None:
c.execute('DELETE FROM meta')
c.execute('''INSERT INTO meta VALUES(?)''', (DB_VERSION,))
-def db_save_repo_heads(c, heads):
+def db_save_repo_heads(c: sqlite3.Cursor, heads: list) -> None:
c.execute('DELETE FROM heads')
for refname, commit_id in heads:
c.execute('''INSERT INTO heads VALUES(?,?)''', (refname, commit_id))
-def db_get_repo_heads(c):
+def db_get_repo_heads(c: sqlite3.Cursor) -> List[Tuple]:
return c.execute('SELECT refname, commit_id FROM heads').fetchall()
-def db_init_common_sqlite_db(c):
+def db_init_common_sqlite_db(c: sqlite3.Cursor) -> None:
c.execute('''
CREATE TABLE meta (
version INTEGER
@@ -330,7 +336,7 @@ def db_init_common_sqlite_db(c):
db_save_meta(c)
-def db_init_cache_sqlite_db(c):
+def db_init_cache_sqlite_db(c: sqlite3.Cursor) -> None:
logger.info('Initializing new sqlite3 db with metadata version %s', DB_VERSION)
db_init_common_sqlite_db(c)
c.execute('''
@@ -343,7 +349,7 @@ def db_init_cache_sqlite_db(c):
c.execute('''CREATE UNIQUE INDEX idx_rev ON revs(rev)''')
-def db_init_pw_sqlite_db(c):
+def db_init_pw_sqlite_db(c: sqlite3.Cursor) -> None:
logger.info('Initializing new sqlite3 db with metadata version %s', DB_VERSION)
db_init_common_sqlite_db(c)
c.execute('''
@@ -353,7 +359,7 @@ def db_init_pw_sqlite_db(c):
)''')
-def git_get_command_lines(gitdir, args):
+def git_get_command_lines(gitdir: str, args: List[str]) -> list:
out = git_run_command(gitdir, args)
lines = list()
if out:
@@ -365,7 +371,7 @@ def git_get_command_lines(gitdir, args):
return lines
-def git_run_command(gitdir, args, stdin=None):
+def git_run_command(gitdir: str, args: List[str], stdin: Optional[str] = None) -> str:
args = ['git', '--no-pager', '--git-dir', gitdir] + args
logger.debug('Running %s' % ' '.join(args))
@@ -383,7 +389,7 @@ def git_run_command(gitdir, args, stdin=None):
return output
-def git_get_repo_heads(gitdir, branch, ancestry=None):
+def git_get_repo_heads(gitdir: str, branch: str, ancestry: Optional[str] = None) -> List[Tuple[str, str]]:
refs = list()
lines = git_get_command_lines(gitdir, ['show-ref', branch])
if ancestry is None:
@@ -397,7 +403,8 @@ def git_get_repo_heads(gitdir, branch, ancestry=None):
return refs
-def git_get_new_revs(gitdir, db_heads, git_heads, committers, merges=False):
+def git_get_new_revs(gitdir: str, db_heads: List[Tuple[str, str]], git_heads: List[Tuple[str, str]],
+ committers: List[str], merges: bool = False) -> Dict[str, list]:
newrevs = dict()
if committers:
logger.debug('filtering by committers=%s', committers)
@@ -453,12 +460,12 @@ def git_get_new_revs(gitdir, db_heads, git_heads, committers, merges=False):
return newrevs
-def git_get_rev_diff(gitdir, rev):
+def git_get_rev_diff(gitdir: str, rev: str) -> str:
args = ['diff', '%s~..%s' % (rev, rev)]
return git_run_command(gitdir, args)
-def git_get_patch_id(diff):
+def git_get_patch_id(diff: str) -> Optional[str]:
args = ['patch-id', '--stable']
out = git_run_command('', args, stdin=diff)
logger.debug('out=%s', out)
@@ -467,7 +474,7 @@ def git_get_patch_id(diff):
return out.split()[0]
-def get_patchwork_hash(diff):
+def get_patchwork_hash(diff: str) -> str:
"""Generate a hash from a diff. Lifted verbatim from patchwork."""
# normalise spaces
@@ -515,13 +522,14 @@ def get_patchwork_hash(diff):
return hashed.hexdigest()
-def listify(obj):
+def listify(obj: Union[str, list, None]) -> list:
if isinstance(obj, list):
return list(obj)
return [obj]
-def send_summary(serieslist, committers, to_state, refname, pname, rs, hs):
+def send_summary(serieslist: List[dict], committers: Dict[int, str], to_state: str, refname: str, pname: str,
+ rs: Dict[str, str], hs: Dict[str, str]) -> str:
logger.info('Preparing summary')
# we send summaries by project, so the project name is going to be all the same
@@ -610,7 +618,7 @@ def send_summary(serieslist, committers, to_state, refname, pname, rs, hs):
return str(msg['Message-Id'])
-def get_tweaks(pconfig, hconfig):
+def get_tweaks(pconfig: Dict[str, str], hconfig: Dict[str, str]) -> Dict[str, str]:
fields = ['from', 'summaryto', 'onlyto', 'neverto', 'onlyifcc', 'neverifcc',
'alwayscc', 'alwaysbcc', 'cclist', 'ccall']
bubbled = dict()
@@ -623,7 +631,9 @@ def get_tweaks(pconfig, hconfig):
return bubbled
-def notify_submitters(serieslist, committers, refname, revs, pname, rs, hs):
+def notify_submitters(serieslist: List[dict], committers: Dict[int, str], refname: str,
+ revs: Dict[int, str], pname: str, rs: Dict[str, Union[str, list, dict]],
+ hs: Dict[str, Union[str, list, dict]]) -> None:
logger.info('Sending submitter notifications')
project, rm, pconfig = project_by_name(pname)
@@ -634,18 +644,26 @@ def notify_submitters(serieslist, committers, refname, revs, pname, rs, hs):
# else the reference is the msgid of the first patch
patches = sdata.get('patches')
is_pull_request = False
+ content = headers = reference = None
if sdata.get('cover_letter'):
reference = sdata.get('cover_letter').get('msgid')
- fullcover = rm.get_cover(sdata.get('cover_letter').get('id'))
- headers = {k.lower(): v for k, v in fullcover.get('headers').items()}
- content = fullcover.get('content')
- else:
+ try:
+ fullcover = rm.get_cover(sdata.get('cover_letter').get('id'))
+ headers = {k.lower(): v for k, v in fullcover.get('headers').items()}
+ content = fullcover.get('content')
+ except KeyError:
+ logger.debug('Unable to get cover letter, will try first patch')
+ if not reference:
reference = patches[0].get('msgid')
- fullpatch = rm.get_patch(patches[0].get('id'))
- headers = {k.lower(): v for k, v in fullpatch.get('headers').items()}
- content = fullpatch.get('content')
- if fullpatch.get('pull_url'):
- is_pull_request = True
+ try:
+ fullpatch = rm.get_patch(patches[0].get('id'))
+ headers = {k.lower(): v for k, v in fullpatch.get('headers').items()}
+ content = fullpatch.get('content')
+ if fullpatch.get('pull_url'):
+ is_pull_request = True
+ except KeyError:
+ logger.debug('Unable to get first patch reference, bailing on %s', sdata.get('id'))
+ continue
submitter = sdata.get('submitter')
project = sdata.get('project')
@@ -784,7 +802,7 @@ def notify_submitters(serieslist, committers, refname, revs, pname, rs, hs):
logger.info('------------------------------')
-def housekeeping(pname):
+def housekeeping(pname: str) -> None:
project, rm, pconfig = project_by_name(pname)
if 'housekeeping' not in pconfig:
return
@@ -1030,7 +1048,7 @@ def housekeeping(pname):
logger.info('------------------------------')
-def pwrun(repo, rsettings):
+def pwrun(repo: str, rsettings: Dict[str, Union[str, list, dict]]) -> None:
git_heads = git_get_repo_heads(repo, branch=rsettings.get('branch', '--heads'))
if not git_heads:
logger.info('Could not get the latest ref in %s', repo)
@@ -1288,7 +1306,7 @@ def pwrun(repo, rsettings):
dbconn.commit()
-def check_repos():
+def check_repos() -> None:
# Use a global lock to make sure only a single process is running
try:
lockfh = open(os.path.join(CACHEDIR, 'patchwork-bot.global.lock'), 'w')
@@ -1310,7 +1328,7 @@ def check_repos():
pwrun(fullpath, settings)
-def pwhash_differ():
+def pwhash_differ() -> None:
diff = sys.stdin.read()
pwhash = get_patchwork_hash(diff)
print(pwhash)