Skip to content

Commit

Permalink
viisar#14 refactored combine method to support weighted combination r…
Browse files Browse the repository at this point in the history
…ules
  • Loading branch information
ivolima committed Oct 28, 2016
1 parent 48a5647 commit 605e33b
Showing 1 changed file with 35 additions and 5 deletions.
40 changes: 35 additions & 5 deletions brew/combination/combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

class Combiner(object):

__VALID_WEIGHTED_COMBINATION_RULES = [
rules.majority_vote_rule,
rules.mean_rule,
]

def __init__(self, rule='majority_vote'):
self.combination_rule = rule

Expand All @@ -25,13 +30,38 @@ def __init__(self, rule='majority_vote'):
else:
raise Exception('invalid argument rule for Combiner class')

def combine(self, results):
def combine(self, results, weights=None):
"""
This method puts together the results of all classifiers
based on a pre-selected combination rule.
Parameters
----------
results: array-like, shape = [n_samples, n_classes, n_classifiers]
If combination rule is 'majority_vote' results should be Ensemble.output(X, mode='votes')
Otherwise, Ensemble.output(X, mode='probs')
weights: array-like, optional(default=None)
Weights of the classifiers. Must have the same size of n_classifiers.
Applies only to 'majority_vote' and 'mean' combination rules.
"""

nresults = results.copy().astype(float)
n_samples = nresults.shape[0]
y_pred = np.zeros((n_samples,))

n_samples = results.shape[0]
if weights is not None:
# verify valid combination rules
if self.rule in __VALID_WEIGHTED_COMBINATION_RULES:
# verify shapes
if weights.shape[0] != nresults.shape[2]:
raise Exception(
'weights and classifiers must have same size')

out = np.zeros((n_samples,))
# apply weights
for i in range(nresults.shape[2]):
nresults[:, :, i] = nresults[:, :, i] * weights[i]

for i in range(n_samples):
out[i] = self.rule(results[i, :, :])
y_pred[i] = self.rule(nresults[i, :, :])

return out
return y_pred

0 comments on commit 605e33b

Please sign in to comment.