def _energy_kernel_row_means(
    X: np.ndarray, norms: np.ndarray, memory_safe: bool
) -> np.ndarray:
    """Return, for every sample, the mean of k(x_i, x_j) over all j.

    The kernel is k(x, y) = ||x|| + ||y|| - ||x - y||.
    """
    n = X.shape[0]
    if memory_safe:
        # One row of the Gram matrix at a time: O(n) extra memory instead of O(n^2).
        k0_mean = np.zeros(n)
        for i in range(n):
            kxy = norms[i : i + 1, None] + norms[None, :] - cdist(X[i : i + 1], X)
            k0_mean[i] = np.mean(kxy)
        return k0_mean

    gram_matrix = norms[:, None] + norms[None, :] - cdist(X, X)
    return np.mean(gram_matrix, axis=1)


def mmd_subsample_fn(
    X: np.ndarray, size: int, initial_ids: list[int] = None, memory_safe: bool = False
) -> np.ndarray:
    """Select samples of the input dataset by greedily minimizing the maximum mean discrepancy (MMD).

    The discrepancy is measured with the kernel
    ``k(x, y) = ||x|| + ||y|| - ||x - y||`` (Euclidean norms). Points are picked one
    at a time; each new point is the ``argmin`` of the greedy objective given the
    points already selected (on exact ties the lowest index wins, as per ``np.argmin``).

    Args:
        X(np.ndarray): input dataset of shape n_samples x n_features
        size(int): number of samples to select. Must not exceed n_samples.
        initial_ids(list[int]): a list of ids of points to initialize the greedy algorithm.
            When given and non-empty, it must contain strictly fewer than ``size`` ids;
            they are copied as-is to the beginning of the result. Defaults to None.
        memory_safe(bool): if True, avoids building the full n_samples x n_samples kernel
            matrix (rows are processed one at a time). Useful for large datasets.
            Defaults to False.

    Returns:
        np.ndarray: int64 array of shape (size,) containing the indices (rows of ``X``)
        of the selected samples, in selection order. Empty when ``size`` is 0.

    Raises:
        AssertionError: if ``size`` is larger than the number of samples, or if
            ``initial_ids`` is non-empty and does not contain fewer than ``size`` ids.

    Example:
        .. code-block:: python

            # Let X be drawn from a standard 10-dimensional Gaussian distribution
            np.random.seed(0)
            X = np.random.randn(1000,10)

            # Select 100 particles
            idx = mmd_subsample_fn(X, size=100)

            print(idx)
            >>> [765 113 171 727 796 855 715 207 458 603 23 384 860 3 459 708 794 138
                 221 639 8 816 619 806 398 236 36 404 167 87 201 676 961 624 556 840
                 485 975 283 150 554 409 69 769 332 357 388 216 900 134 15 730 80 694
                 251 714 11 817 525 382 328 67 356 514 597 668 959 260 968 26 209 789
                 305 122 989 571 801 322 14 160 908 12 1 980 582 440 42 452 666 526
                 290 231 712 21 606 575 656 950 879 948]

            # In this simple Gaussian example, the means and standard deviations of the
            # selected subsample should be close to the ones of the original sample

            print(np.abs(np.mean(X[idx], axis=0) - np.mean(X, axis=0)))
            >>> [0.00280955 0.00220179 0.01359079 0.00461107 0.0011997 0.01106616
                 0.01157571 0.0061314 0.00813494 0.0026543]
            print(np.abs(np.std(X[idx], axis=0) - np.std(X, axis=0)))
            >>> [0.0067711 0.00316008 0.00860733 0.07130127 0.02858514 0.0173707
                 0.00739646 0.03526784 0.0054039 0.00351996]
    """
    n = X.shape[0]
    # NOTE: kept as `assert` so callers still observe AssertionError on bad input.
    assert size <= n, "size must not exceed the number of samples in X"

    # Use len() rather than truthiness: initial_ids may be a NumPy array.
    n_init = 0 if initial_ids is None else len(initial_ids)
    if n_init > 0:
        assert n_init < size, "initial_ids must contain fewer than `size` ids"

    idx = np.zeros(size, dtype=np.int64)
    if size == 0:
        # Nothing to select; previously this crashed when indexing an (n, 0) buffer.
        return idx

    norms = np.linalg.norm(X, axis=1)
    k0_mean = _energy_kernel_row_means(X, norms, memory_safe)

    # Column 0 holds k(x, x) = 2 ||x||; column i >= 1 holds k(x, X[idx[i - 1]]).
    k0 = np.zeros((n, size))
    k0[:, 0] = 2.0 * norms
    if n_init > 0:
        idx[:n_init] = initial_ids

    for i in range(size):
        if i > 0:
            prev = idx[i - 1]
            dist = np.linalg.norm(X - X[prev], axis=1)
            k0[:, i] = -dist + norms[prev] + norms

        # Positions below n_init are imposed by the caller; only their kernel
        # columns are needed. Every later position is chosen greedily.
        if i >= n_init:
            idx[i] = np.argmin(
                k0[:, 0]
                + 2.0 * np.sum(k0[:, 1 : (i + 1)], axis=1)
                - 2.0 * (i + 1) * k0_mean
            )
    return idx