@@ -159,4 +159,95 @@ def check_options_validity(split_option: dict):
159
159
160
160
return _splits
161
161
162
- # %% Classes
162
+
163
+ def mmd_subsample_fn (
164
+ X : np .ndarray , size : int , initial_ids : list [int ] = None , memory_safe : bool = False
165
+ ) -> np .ndarray :
166
+ """Selects samples in the input dataset by greedily minimizing the maximum mena discrepancy (MMD).
167
+
168
+ Args:
169
+ X(np.ndarray): input dataset of shape n_samples x n_features
170
+ size(int): number of samples to select
171
+ initial_ids(list[int]): a list of ids of points to initialize the gready algorithm. Defaults to None.
172
+ memory_safe(bool): if True, avoids a memory expensive computation. Useful for large datasets. Defaults to False.
173
+
174
+ Returns:
175
+ np.ndarray: array of selected samples
176
+
177
+ Example:
178
+ .. code-block:: python
179
+
180
+ # Let X be drawn from a standard 10-dimensional Gaussian distribution
181
+ np.random.seed(0)
182
+ X = np.random.randn(1000,10)
183
+
184
+ # Select 100 particles
185
+ idx = mmd_subsample_fn(X, size=100)
186
+
187
+ print(idx)
188
+ >>> [765 113 171 727 796 855 715 207 458 603 23 384 860 3 459 708 794 138
189
+ 221 639 8 816 619 806 398 236 36 404 167 87 201 676 961 624 556 840
190
+ 485 975 283 150 554 409 69 769 332 357 388 216 900 134 15 730 80 694
191
+ 251 714 11 817 525 382 328 67 356 514 597 668 959 260 968 26 209 789
192
+ 305 122 989 571 801 322 14 160 908 12 1 980 582 440 42 452 666 526
193
+ 290 231 712 21 606 575 656 950 879 948]
194
+
195
+ # In this simple Gaussian example, the means and standard deviations of the
196
+ # selected subsample should be close to the ones of the original sample
197
+
198
+ print(np.abs(np.mean(x[idx], axis=0) - np.mean(x, axis=0)))
199
+ >>> [0.00280955 0.00220179 0.01359079 0.00461107 0.0011997 0.01106616
200
+ 0.01157571 0.0061314 0.00813494 0.0026543]
201
+ print(np.abs(np.std(x[idx], axis=0) - np.std(x, axis=0)))
202
+ >>> [0.0067711 0.00316008 0.00860733 0.07130127 0.02858514 0.0173707
203
+ 0.00739646 0.03526784 0.0054039 0.00351996]
204
+ """
205
+
206
+ n = X .shape [0 ]
207
+ assert size <= n
208
+ # Precompute norms and distance matrix
209
+ norms = np .linalg .norm (X , axis = 1 )
210
+ if memory_safe :
211
+ k0_mean = np .zeros (n )
212
+ for i in range (n ):
213
+ kxy = norms [i : i + 1 , None ] + norms [None , :] - cdist (X [i : i + 1 ], X )
214
+ k0_mean [i ] = np .mean (kxy )
215
+ else :
216
+ dist_matrix = cdist (X , X )
217
+ gram_matrix = norms [:, None ] + norms [None , :] - dist_matrix
218
+ k0_mean = np .mean (gram_matrix , axis = 1 )
219
+
220
+ idx = np .zeros (size , dtype = np .int64 )
221
+ if initial_ids is None or len (initial_ids ) == 0 :
222
+ k0 = np .zeros ((n , size ))
223
+ k0 [:, 0 ] = 2.0 * norms
224
+
225
+ idx [0 ] = np .argmin (k0 [:, 0 ] - 2.0 * k0_mean )
226
+ for i in range (1 , size ):
227
+ x_ = X [idx [i - 1 ]]
228
+ dist = np .linalg .norm (X - x_ , axis = 1 )
229
+ k0 [:, i ] = - dist + norms [idx [i - 1 ]] + norms
230
+
231
+ idx [i ] = np .argmin (
232
+ k0 [:, 0 ]
233
+ + 2.0 * np .sum (k0 [:, 1 : (i + 1 )], axis = 1 )
234
+ - 2.0 * (i + 1 ) * k0_mean
235
+ )
236
+ else :
237
+ assert len (initial_ids ) < size
238
+ idx [: len (initial_ids )] = initial_ids
239
+ k0 = np .zeros ((n , size ))
240
+
241
+ k0 [:, 0 ] = 2.0 * norms
242
+ for i in range (1 , size ):
243
+ x_ = X [idx [i - 1 ]]
244
+ dist = np .linalg .norm (X - x_ , axis = 1 )
245
+ k0 [:, i ] = - dist + norms [idx [i - 1 ]] + norms
246
+
247
+ if i >= len (initial_ids ):
248
+ idx [i ] = np .argmin (
249
+ k0 [:, 0 ]
250
+ + 2.0 * np .sum (k0 [:, 1 : (i + 1 )], axis = 1 )
251
+ - 2.0 * (i + 1 ) * k0_mean
252
+ )
253
+ return idx
0 commit comments