Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reset random number generator for reproducibility #301

Merged
merged 1 commit into from
Jan 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cornac/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,13 @@ def from_uirt(cls, data, seed=None):
"""
return cls.build(data, fmt="UIRT", seed=seed)

def reset(self):
    """Reset the random number generator for reproducibility.

    Replaces ``self.rng`` with the result of ``get_rng(self.seed)``.

    Returns
    -------
    self : object
        This same instance, so callers can write ``obj = obj.reset()``.
    """
    fresh_rng = get_rng(self.seed)
    self.rng = fresh_rng
    return self

def num_batches(self, batch_size):
    """Estimate number of batches per epoch.

    Parameters
    ----------
    batch_size
        Batch size, forwarded unchanged to ``estimate_batches``.

    Returns
    -------
    The value produced by ``estimate_batches`` for the length of
    ``self.uir_tuple[0]`` and ``batch_size``.
    """
    # The first element of uir_tuple determines how many entries there are.
    n_entries = len(self.uir_tuple[0])
    return estimate_batches(n_entries, batch_size)

def idx_iter(self, idx_range, batch_size=1, shuffle=False):
Expand Down
12 changes: 10 additions & 2 deletions cornac/eval_methods/base_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ..metrics import RatingMetric
from ..metrics import RankingMetric
from ..experiment.result import Result
from ..utils import get_rng


def rating_eval(model, metrics, test_set, user_based=False):
Expand Down Expand Up @@ -239,6 +240,7 @@ def __init__(
self.exclude_unknowns = exclude_unknowns
self.verbose = verbose
self.seed = seed
self.rng = get_rng(seed)
self.global_uid_map = OrderedDict()
self.global_iid_map = OrderedDict()

Expand Down Expand Up @@ -362,6 +364,11 @@ def sentiment(self, input_modality):
)
self.__sentiment = input_modality

def _reset(self):
    """Reset the random number generator for reproducibility.

    Rebuilds ``self.rng`` from ``self.seed``, then resets the test set and
    stores whatever ``test_set.reset()`` returns as ``self.test_set``.
    """
    seed = self.seed
    self.rng = get_rng(seed)
    # reset() hands back the object to keep, so store its result.
    refreshed_test_set = self.test_set.reset()
    self.test_set = refreshed_test_set

def _organize_metrics(self, metrics):
"""Organize metrics according to their types (rating or ranking)

Expand Down Expand Up @@ -559,12 +566,13 @@ def evaluate(self, model, metrics, user_based, show_validation=True):
res: :obj:`cornac.experiment.Result`

"""
self._organize_metrics(metrics)

if self.train_set is None:
raise ValueError("train_set is required but None!")
if self.test_set is None:
raise ValueError("test_set is required but None!")

self._reset()
self._organize_metrics(metrics)

###########
# FITTING #
Expand Down
7 changes: 2 additions & 5 deletions cornac/eval_methods/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,14 @@ def __init__(

def _partition_data(self):
"""Partition ratings into n_folds"""

rng = get_rng(self.seed)

fold_size = int(self.n_ratings / self.n_folds)
remain_size = self.n_ratings - fold_size * self.n_folds

partition = np.repeat(np.arange(self.n_folds), fold_size)
rng.shuffle(partition)
self.rng.shuffle(partition)

if remain_size > 0:
remain_partition = rng.choice(
remain_partition = self.rng.choice(
self.n_folds, size=remain_size, replace=True, p=None
)
partition = np.concatenate((partition, remain_partition))
Expand Down
2 changes: 1 addition & 1 deletion cornac/eval_methods/ratio_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def validate_size(val_size, test_size, num_ratings):
return int(train_size), int(val_size), int(test_size)

def _split(self):
data_idx = get_rng(self.seed).permutation(len(self._data))
data_idx = self.rng.permutation(len(self._data))
train_idx = data_idx[:self.train_size]
test_idx = data_idx[-self.test_size:]
val_idx = data_idx[self.train_size:-self.test_size]
Expand Down
4 changes: 2 additions & 2 deletions cornac/models/recommender.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def fit(self, train_set, val_set=None):
self : object
"""
self.reset_info()
self.train_set = train_set
self.val_set = val_set
self.train_set = train_set.reset()
self.val_set = None if val_set is None else val_set.reset()
return self

def score(self, user_idx, item_idx=None):
Expand Down