🎉 MMD split of input scalars #21

Draft: wants to merge 2 commits into main
93 changes: 92 additions & 1 deletion src/plaid/utils/split.py
@@ -159,4 +159,95 @@ def check_options_validity(split_option: dict):

return _splits

# %% Classes

def mmd_subsample_fn(
    X: np.ndarray, size: int, initial_ids: list[int] | None = None, memory_safe: bool = False
) -> np.ndarray:
"""Selects samples in the input dataset by greedily minimizing the maximum mena discrepancy (MMD).

Args:
        X (np.ndarray): input dataset of shape (n_samples, n_features).
        size (int): number of samples to select.
        initial_ids (list[int], optional): indices of points used to initialize the greedy algorithm. Defaults to None.
        memory_safe (bool): if True, avoids a memory-expensive computation of the full Gram matrix. Useful for large datasets. Defaults to False.

Returns:
        np.ndarray: array of indices of the selected samples.

Example:
.. code-block:: python

# Let X be drawn from a standard 10-dimensional Gaussian distribution
np.random.seed(0)
            X = np.random.randn(1000, 10)

# Select 100 particles
idx = mmd_subsample_fn(X, size=100)

print(idx)
>>> [765 113 171 727 796 855 715 207 458 603 23 384 860 3 459 708 794 138
221 639 8 816 619 806 398 236 36 404 167 87 201 676 961 624 556 840
485 975 283 150 554 409 69 769 332 357 388 216 900 134 15 730 80 694
251 714 11 817 525 382 328 67 356 514 597 668 959 260 968 26 209 789
305 122 989 571 801 322 14 160 908 12 1 980 582 440 42 452 666 526
290 231 712 21 606 575 656 950 879 948]

# In this simple Gaussian example, the means and standard deviations of the
# selected subsample should be close to the ones of the original sample

            print(np.abs(np.mean(X[idx], axis=0) - np.mean(X, axis=0)))
>>> [0.00280955 0.00220179 0.01359079 0.00461107 0.0011997 0.01106616
0.01157571 0.0061314 0.00813494 0.0026543]
            print(np.abs(np.std(X[idx], axis=0) - np.std(X, axis=0)))
>>> [0.0067711 0.00316008 0.00860733 0.07130127 0.02858514 0.0173707
0.00739646 0.03526784 0.0054039 0.00351996]
"""

n = X.shape[0]
assert size <= n
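    # NOTE: `cdist` is assumed to be scipy.spatial.distance.cdist, imported at
    # the top of the module (outside this hunk). The selection greedily
    # minimizes the squared MMD between the subsample and the full dataset
    # under the distance-induced (energy) kernel
    # k(x, y) = ||x|| + ||y|| - ||x - y||.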
    # Precompute norms and the mean kernel value of each sample against the dataset
norms = np.linalg.norm(X, axis=1)
if memory_safe:
k0_mean = np.zeros(n)
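        # Row-by-row pass: compute each sample's mean kernel value against the
        # whole dataset using O(n) memory instead of the full n x n matrix.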
for i in range(n):
kxy = norms[i : i + 1, None] + norms[None, :] - cdist(X[i : i + 1], X)
k0_mean[i] = np.mean(kxy)
else:
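        # Build the full Gram matrix at once; k0_mean[i] is the mean kernel
        # value between sample i and the whole dataset.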
dist_matrix = cdist(X, X)
gram_matrix = norms[:, None] + norms[None, :] - dist_matrix
k0_mean = np.mean(gram_matrix, axis=1)

idx = np.zeros(size, dtype=np.int64)
if initial_ids is None or len(initial_ids) == 0:
k0 = np.zeros((n, size))
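        # k0[:, 0] is the self-term k(x, x) = 2 * ||x||; column i (i >= 1)
        # will hold the kernel against the point selected at step i - 1.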
k0[:, 0] = 2.0 * norms

        # First point: minimize k(x, x) - 2 * mean_y k(x, y), the squared MMD
        # of the single-point subsample {x} up to a constant.
        idx[0] = np.argmin(k0[:, 0] - 2.0 * k0_mean)
for i in range(1, size):
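            # Kernel between every candidate and the point selected at the
            # previous step: k(x, x_) = ||x|| + ||x_|| - ||x - x_||.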
x_ = X[idx[i - 1]]
dist = np.linalg.norm(X - x_, axis=1)
k0[:, i] = -dist + norms[idx[i - 1]] + norms

            # Greedy step: pick the candidate x minimizing
            # k(x, x) + 2 * sum_j k(x, x_j) - 2 * (i + 1) * mean_y k(x, y),
            # the squared MMD of the enlarged subsample up to constants and a
            # positive factor.
            idx[i] = np.argmin(
                k0[:, 0]
                + 2.0 * np.sum(k0[:, 1 : (i + 1)], axis=1)
                - 2.0 * (i + 1) * k0_mean
            )
else:
assert len(initial_ids) < size
idx[: len(initial_ids)] = initial_ids
k0 = np.zeros((n, size))

k0[:, 0] = 2.0 * norms
for i in range(1, size):
x_ = X[idx[i - 1]]
dist = np.linalg.norm(X - x_, axis=1)
k0[:, i] = -dist + norms[idx[i - 1]] + norms

            # Only start selecting new points once the imposed initial ids
            # have been passed; their kernel columns are still accumulated above.
            if i >= len(initial_ids):
idx[i] = np.argmin(
k0[:, 0]
+ 2.0 * np.sum(k0[:, 1 : (i + 1)], axis=1)
- 2.0 * (i + 1) * k0_mean
)
return idx
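
A minimal usage sketch for the warm-start path (reviewer note, not part of the diff). It assumes the module is importable as plaid.utils.split, matching the src/plaid/utils/split.py layout, and reuses the Gaussian data from the docstring example; the first ten indices are imposed via initial_ids and should be preserved in the output:

    import numpy as np
    from plaid.utils.split import mmd_subsample_fn

    np.random.seed(0)
    X = np.random.randn(1000, 10)

    # Impose the first ten points; the greedy loop selects the remaining 90.
    idx = mmd_subsample_fn(X, size=100, initial_ids=list(range(10)))
    assert np.array_equal(idx[:10], np.arange(10))
    assert idx.shape == (100,)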