Skip to content

Commit 4e9e1c6

Browse files
xroynard and fabiencasenave
authored and committed
(split) improve doc of MMD function
1 parent 36fb6a4 commit 4e9e1c6

File tree

1 file changed

+92
-1
lines changed

1 file changed

+92
-1
lines changed

src/plaid/utils/split.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,4 +159,95 @@ def check_options_validity(split_option: dict):
159159

160160
return _splits
161161

162-
# %% Classes
162+
163+
def mmd_subsample_fn(
    X: np.ndarray, size: int, initial_ids: list[int] = None, memory_safe: bool = False
) -> np.ndarray:
    """Selects samples in the input dataset by greedily minimizing the maximum mean discrepancy (MMD).

    The kernel used is k(x, y) = ||x|| + ||y|| - ||x - y||. Each greedy step
    picks the candidate point that minimizes the (unnormalized) MMD between
    the selected subset and the full dataset.

    Args:
        X (np.ndarray): input dataset of shape (n_samples, n_features).
        size (int): number of samples to select; must satisfy 1 <= size <= n_samples.
        initial_ids (list[int]): a list of ids of points to initialize the greedy
            algorithm; must contain fewer than ``size`` ids. Defaults to None.
        memory_safe (bool): if True, avoids building the full n x n Gram matrix
            (one row at a time instead). Useful for large datasets. Defaults to False.

    Returns:
        np.ndarray: array of ``size`` indices of the selected samples.

    Raises:
        ValueError: if ``size`` is not in ``[1, n_samples]``, or if
            ``initial_ids`` contains ``size`` or more ids.

    Example:
        .. code-block:: python

            # Let X be drawn from a standard 10-dimensional Gaussian distribution
            np.random.seed(0)
            X = np.random.randn(1000,10)

            # Select 100 particles
            idx = mmd_subsample_fn(X, size=100)

            print(idx)
            >>> [765 113 171 727 796 855 715 207 458 603  23 384 860   3 459 708 794 138
             221 639   8 816 619 806 398 236  36 404 167  87 201 676 961 624 556 840
             485 975 283 150 554 409  69 769 332 357 388 216 900 134  15 730  80 694
             251 714  11 817 525 382 328  67 356 514 597 668 959 260 968  26 209 789
             305 122 989 571 801 322  14 160 908  12   1 980 582 440  42 452 666 526
             290 231 712  21 606 575 656 950 879 948]

            # In this simple Gaussian example, the means and standard deviations of the
            # selected subsample should be close to the ones of the original sample

            print(np.abs(np.mean(X[idx], axis=0) - np.mean(X, axis=0)))
            >>> [0.00280955 0.00220179 0.01359079 0.00461107 0.0011997  0.01106616
             0.01157571 0.0061314  0.00813494 0.0026543]
            print(np.abs(np.std(X[idx], axis=0) - np.std(X, axis=0)))
            >>> [0.0067711  0.00316008 0.00860733 0.07130127 0.02858514 0.0173707
             0.00739646 0.03526784 0.0054039  0.00351996]
    """
    n = X.shape[0]
    if not 1 <= size <= n:
        raise ValueError(f"size must be in [1, {n}], got {size}")
    seed_ids = list(initial_ids) if initial_ids is not None else []
    n_init = len(seed_ids)
    if n_init >= size:
        raise ValueError("initial_ids must contain fewer than `size` ids")

    # Precompute norms and the per-point Gram-matrix row means:
    # k0_mean[i] = mean_j (||x_i|| + ||x_j|| - ||x_i - x_j||).
    norms = np.linalg.norm(X, axis=1)
    if memory_safe:
        # One row at a time: O(n) memory per step instead of an O(n^2) matrix.
        k0_mean = np.empty(n)
        for i in range(n):
            row = norms[i] + norms - cdist(X[i : i + 1], X)[0]
            k0_mean[i] = np.mean(row)
    else:
        gram_matrix = norms[:, None] + norms[None, :] - cdist(X, X)
        k0_mean = np.mean(gram_matrix, axis=1)

    idx = np.zeros(size, dtype=np.int64)
    # k0[:, i] caches the kernel values of every point against the i-th
    # selected point; column 0 holds k(x, x) = 2 * ||x||.
    k0 = np.zeros((n, size))
    k0[:, 0] = 2.0 * norms

    if n_init == 0:
        # No seed: pick the point minimizing the one-sample MMD objective.
        idx[0] = np.argmin(k0[:, 0] - 2.0 * k0_mean)
    else:
        idx[:n_init] = seed_ids

    # Positions before `first_free` are fixed (by the seed, or by the first
    # greedy pick above); later positions are chosen greedily.
    first_free = max(n_init, 1)
    for i in range(1, size):
        prev = idx[i - 1]
        dist = np.linalg.norm(X - X[prev], axis=1)
        k0[:, i] = norms[prev] + norms - dist
        if i >= first_free:
            # Greedy step: minimize the MMD of the subset extended with each
            # candidate point (constant terms dropped).
            idx[i] = np.argmin(
                k0[:, 0]
                + 2.0 * np.sum(k0[:, 1 : (i + 1)], axis=1)
                - 2.0 * (i + 1) * k0_mean
            )
    return idx

0 commit comments

Comments
 (0)