diff --git a/cornac/data/dataset.py b/cornac/data/dataset.py
index 27496a754..c7f0a4522 100644
--- a/cornac/data/dataset.py
+++ b/cornac/data/dataset.py
@@ -426,7 +426,13 @@ def from_uirt(cls, data, seed=None):
         """
         return cls.build(data, fmt="UIRT", seed=seed)
 
+    def reset(self):
+        """Reset the random number generator for reproducibility"""
+        self.rng = get_rng(self.seed)
+        return self
+
     def num_batches(self, batch_size):
+        """Estimate number of batches per epoch"""
         return estimate_batches(len(self.uir_tuple[0]), batch_size)
 
     def idx_iter(self, idx_range, batch_size=1, shuffle=False):
diff --git a/cornac/eval_methods/base_method.py b/cornac/eval_methods/base_method.py
index 02be510ef..c097419d3 100644
--- a/cornac/eval_methods/base_method.py
+++ b/cornac/eval_methods/base_method.py
@@ -28,6 +28,7 @@
 from ..metrics import RatingMetric
 from ..metrics import RankingMetric
 from ..experiment.result import Result
+from ..utils import get_rng
 
 
 def rating_eval(model, metrics, test_set, user_based=False):
@@ -239,6 +240,7 @@ def __init__(
         self.exclude_unknowns = exclude_unknowns
         self.verbose = verbose
         self.seed = seed
+        self.rng = get_rng(seed)
         self.global_uid_map = OrderedDict()
         self.global_iid_map = OrderedDict()
 
@@ -362,6 +364,11 @@ def sentiment(self, input_modality):
             )
         self.__sentiment = input_modality
 
+    def _reset(self):
+        """Reset the random number generator for reproducibility"""
+        self.rng = get_rng(self.seed)
+        self.test_set = self.test_set.reset()
+
     def _organize_metrics(self, metrics):
         """Organize metrics according to their types (rating or raking)
 
@@ -559,12 +566,13 @@ def evaluate(self, model, metrics, user_based, show_validation=True):
         res: :obj:`cornac.experiment.Result`
 
         """
-        self._organize_metrics(metrics)
-
         if self.train_set is None:
             raise ValueError("train_set is required but None!")
         if self.test_set is None:
             raise ValueError("test_set is required but None!")
+
+        self._reset()
+        self._organize_metrics(metrics)
 
         ###########
         # FITTING #
diff --git a/cornac/eval_methods/cross_validation.py b/cornac/eval_methods/cross_validation.py
index ba0391ca5..7b0d11d51 100644
--- a/cornac/eval_methods/cross_validation.py
+++ b/cornac/eval_methods/cross_validation.py
@@ -81,17 +81,14 @@ def __init__(
 
     def _partition_data(self):
         """Partition ratings into n_folds"""
-
-        rng = get_rng(self.seed)
-
         fold_size = int(self.n_ratings / self.n_folds)
         remain_size = self.n_ratings - fold_size * self.n_folds
 
         partition = np.repeat(np.arange(self.n_folds), fold_size)
-        rng.shuffle(partition)
+        self.rng.shuffle(partition)
 
         if remain_size > 0:
-            remain_partition = rng.choice(
+            remain_partition = self.rng.choice(
                 self.n_folds, size=remain_size, replace=True, p=None
             )
             partition = np.concatenate((partition, remain_partition))
diff --git a/cornac/eval_methods/ratio_split.py b/cornac/eval_methods/ratio_split.py
index 549a0a46a..fd1202812 100644
--- a/cornac/eval_methods/ratio_split.py
+++ b/cornac/eval_methods/ratio_split.py
@@ -93,7 +93,7 @@ def validate_size(val_size, test_size, num_ratings):
         return int(train_size), int(val_size), int(test_size)
 
     def _split(self):
-        data_idx = get_rng(self.seed).permutation(len(self._data))
+        data_idx = self.rng.permutation(len(self._data))
         train_idx = data_idx[:self.train_size]
         test_idx = data_idx[-self.test_size:]
         val_idx = data_idx[self.train_size:-self.test_size]
diff --git a/cornac/models/recommender.py b/cornac/models/recommender.py
index 2845ccb24..419fef05b 100644
--- a/cornac/models/recommender.py
+++ b/cornac/models/recommender.py
@@ -96,8 +96,8 @@ def fit(self, train_set, val_set=None):
         self : object
         """
         self.reset_info()
-        self.train_set = train_set
-        self.val_set = val_set
+        self.train_set = train_set.reset()
+        self.val_set = None if val_set is None else val_set.reset()
         return self
 
     def score(self, user_idx, item_idx=None):