diff --git a/gitrisky/cli.py b/gitrisky/cli.py index 38dbd8e..1032d7f 100644 --- a/gitrisky/cli.py +++ b/gitrisky/cli.py @@ -14,7 +14,9 @@ def cli(): @cli.command() -def train(): +@click.option('-p', '--pattern', required=False, + help="Bug fix pattern. Ex. BUG,FIX", type=str) +def train(pattern=None): """Train a git commit bug risk model. This will save a pickled sklearn model to a file in the toplevel directory @@ -23,7 +25,10 @@ def train(): # get the features and labels by parsing the git logs features = get_features() - labels = get_labels() + if pattern is not None: + labels = get_labels(pattern.split(',')) + else: + labels = get_labels(pattern) # instantiate and train a model model = create_model() diff --git a/gitrisky/gitcmds.py b/gitrisky/gitcmds.py index 72593a3..c0f7841 100644 --- a/gitrisky/gitcmds.py +++ b/gitrisky/gitcmds.py @@ -3,7 +3,7 @@ import re from collections import defaultdict -from subprocess import check_output +from subprocess import check_output, CalledProcessError def _run_bash_command(bash_cmd): @@ -20,7 +20,11 @@ def _run_bash_command(bash_cmd): The resulting stdout output. """ - stdout = check_output(bash_cmd.split()).decode('utf-8').rstrip('\n') + try: + stdout = check_output(bash_cmd.split()).decode('utf-8').rstrip('\n') + except CalledProcessError as err: + print('Failed to execute bash command: {!r}'.format(str(bash_cmd))) + exit(1) return stdout @@ -80,7 +84,7 @@ def get_git_log(commit=None): return stdout -def get_bugfix_commits(): +def get_bugfix_commits(pattern=None): """Get the commits whose commit messages contain BUG or FIX. Returns @@ -89,8 +93,12 @@ def get_bugfix_commits(): A list of commit hashes. """ + if pattern is None: + pattern = ("BUG", "FIX") + # TODO: add option to specify custom bugfix tags - bash_cmd = "git log -i --all --grep BUG --grep FIX --pretty=format:%h" + bash_cmd = "git log -i --all --grep {} --grep {} --pretty=format:%h"\ + .format(pattern[0], pattern[1]) stdout = _run_bash_command(bash_cmd) diff --git a/gitrisky/parsing.py b/gitrisky/parsing.py index 4e5af0a..39d3571 100644 --- a/gitrisky/parsing.py +++ b/gitrisky/parsing.py @@ -135,7 +135,7 @@ def get_features(commit=None): return feats -def get_labels(): +def get_labels(pattern=None): """Get a label for each commit indicating whether it introduced a bug. Returns @@ -147,7 +147,7 @@ def get_labels(): feats = get_features() - fix_commits = get_bugfix_commits() + fix_commits = get_bugfix_commits(pattern) bug_commits = link_fixes_to_bugs(fix_commits)