1
- use ndarray:: { Array2 , ArrayView1 , ArrayView2 , Zip } ;
1
+ use ndarray:: { Array2 , ArrayView1 , ArrayView2 , ArrayViewMut1 } ;
2
2
use numpy:: { PyArray2 , PyReadonlyArray1 , PyReadonlyArray2 } ;
3
3
use pyo3:: prelude:: * ;
4
4
use rayon:: prelude:: * ;
5
- use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
5
+ use std:: sync:: atomic:: { AtomicBool , Ordering } ;
6
6
use std:: sync:: Arc ;
7
7
8
- mod internal {
9
- pub ( super ) fn sad_converged ( a : & [ f64 ] , b : & [ f64 ] , tol : f64 ) -> bool {
10
- a. iter ( ) . zip ( b) . all ( |( & x, & y) | ( x - y) . abs ( ) < tol)
8
+
9
+ #[ allow( non_snake_case) ]
10
+ pub fn demean_impl ( X : & mut Array2 < f64 > , D : ArrayView2 < usize > , weights : ArrayView1 < f64 > , tol : f64 , iterations : usize ) -> bool {
11
+ let nsamples = X . nrows ( ) ;
12
+ let nfactors = D . ncols ( ) ;
13
+ let success = Arc :: new ( AtomicBool :: new ( true ) ) ;
14
+ let group_weights = FactorGroupWeights :: new ( & D , & weights) ;
15
+
16
+ X . axis_iter_mut ( ndarray:: Axis ( 1 ) )
17
+ . into_par_iter ( )
18
+ . for_each ( |mut column| {
19
+ let mut demeaner = ColumnDemeaner :: new ( nsamples, group_weights. width ) ;
20
+
21
+ for _ in 0 ..iterations {
22
+ for i in 0 ..nfactors {
23
+ demeaner. demean_column (
24
+ & mut column,
25
+ & weights,
26
+ & D . column ( i) ,
27
+ group_weights. factor_weight_slice ( i)
28
+ ) ;
29
+ }
30
+
31
+ demeaner. check_convergence ( & column. view ( ) , tol) ;
32
+ if demeaner. converged {
33
+ break ;
34
+ }
35
+ }
36
+
37
+ if !demeaner. converged {
38
+ // We can use a relaxed ordering since we only ever go from true to false
39
+ // and it doesn't matter how many times we do this.
40
+ success. store ( false , Ordering :: Relaxed ) ;
41
+ }
42
+ } ) ;
43
+
44
+ success. load ( Ordering :: Relaxed )
45
+ }
46
+
47
+ // The column demeaner is in charge of subtracting group means until convergence.
48
+ struct ColumnDemeaner {
49
+ converged : bool ,
50
+ checkpoint : Vec < f64 > ,
51
+ group_sums : Vec < f64 > ,
52
+ }
53
+
54
+ impl ColumnDemeaner {
55
+ fn new ( n : usize , k : usize ) -> Self {
56
+ Self {
57
+ converged : false ,
58
+ checkpoint : vec ! [ 0.0 ; n] ,
59
+ group_sums : vec ! [ 0.0 ; k] ,
60
+ }
11
61
}
12
62
13
- pub ( super ) fn subtract_weighted_group_mean (
14
- x : & mut [ f64 ] ,
15
- sample_weights : & [ f64 ] ,
16
- group_ids : & [ usize ] ,
63
+ fn demean_column (
64
+ & mut self ,
65
+ x : & mut ArrayViewMut1 < f64 > ,
66
+ weights : & ArrayView1 < f64 > ,
67
+ groups : & ArrayView1 < usize > ,
17
68
group_weights : & [ f64 ] ,
18
- group_weighted_sums : & mut [ f64 ] ,
19
69
) {
20
- group_weighted_sums . fill ( 0.0 ) ;
70
+ self . group_sums . fill ( 0.0 ) ;
21
71
22
- // Accumulate weighted sums per group
23
- x. iter ( )
24
- . zip ( sample_weights)
25
- . zip ( group_ids)
26
- . for_each ( |( ( & xi, & wi) , & gid) | {
27
- group_weighted_sums[ gid] += wi * xi;
28
- } ) ;
72
+ // Compute group sums
73
+ for ( ( & xi, & wi) , & gid) in x. iter ( ) . zip ( weights) . zip ( groups) {
74
+ self . group_sums [ gid] += wi * xi;
75
+ }
29
76
30
- // Compute group means
31
- let group_means: Vec < f64 > = group_weighted_sums
32
- . iter ( )
33
- . zip ( group_weights)
34
- . map ( |( & sum, & weight) | sum / weight)
35
- . collect ( ) ;
77
+ // Convert sums to means
78
+ self . group_sums
79
+ . iter_mut ( )
80
+ . zip ( group_weights. iter ( ) )
81
+ . for_each ( |( sum, & weight) | {
82
+ * sum /= weight
83
+ } ) ;
36
84
37
- // Subtract means from each sample
38
- x. iter_mut ( ) . zip ( group_ids ) . for_each ( | ( xi , & gid ) | {
39
- * xi -= group_means [ gid] ;
40
- } ) ;
85
+ // Subtract group means
86
+ for ( xi , & gid ) in x. iter_mut ( ) . zip ( groups ) {
87
+ * xi -= self . group_sums [ gid] // Really these are means now
88
+ }
41
89
}
42
90
43
- pub ( super ) fn calc_group_weights (
44
- sample_weights : & [ f64 ] ,
45
- group_ids : & [ usize ] ,
46
- n_samples : usize ,
47
- n_factors : usize ,
48
- n_groups : usize ,
49
- ) -> Vec < f64 > {
50
- let mut group_weights = vec ! [ 0.0 ; n_factors * n_groups ] ;
51
- for i in 0 ..n_samples {
52
- let weight = sample_weights [ i ] ;
53
- for j in 0 ..n_factors {
54
- let id = group_ids [ i * n_factors + j ] ;
55
- group_weights [ j * n_groups + id ] += weight ;
56
- }
57
- }
58
- group_weights
91
+
92
+ // Check elementwise convergence and update checkpoint
93
+ fn check_convergence (
94
+ & mut self ,
95
+ x : & ArrayView1 < f64 > ,
96
+ tol : f64 ,
97
+ ) {
98
+ self . converged = true ; // Innocent until proven guilty
99
+ x . iter ( )
100
+ . zip ( self . checkpoint . iter_mut ( ) )
101
+ . for_each ( | ( & xi , cp ) | {
102
+ if ( xi - * cp ) . abs ( ) > tol {
103
+ self . converged = false ; // Guilty!
104
+ }
105
+ * cp = xi ; // Update checkpoint
106
+ } ) ;
59
107
}
60
108
}
61
109
62
- fn demean_impl (
63
- x : & ArrayView2 < f64 > ,
64
- flist : & ArrayView2 < usize > ,
65
- weights : & ArrayView1 < f64 > ,
66
- tol : f64 ,
67
- maxiter : usize ,
68
- ) -> ( Array2 < f64 > , bool ) {
69
- let ( n_samples, n_features) = x. dim ( ) ;
70
- let n_factors = flist. ncols ( ) ;
71
- let n_groups = flist. iter ( ) . cloned ( ) . max ( ) . unwrap ( ) + 1 ;
72
-
73
- let sample_weights: Vec < f64 > = weights. iter ( ) . cloned ( ) . collect ( ) ;
74
- let group_ids: Vec < usize > = flist. iter ( ) . cloned ( ) . collect ( ) ;
75
- let group_weights =
76
- internal:: calc_group_weights ( & sample_weights, & group_ids, n_samples, n_factors, n_groups) ;
77
-
78
- let not_converged = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
79
-
80
- // Precompute slices of group_ids for each factor
81
- let group_ids_by_factor: Vec < Vec < usize > > = ( 0 ..n_factors)
82
- . map ( |j| {
83
- ( 0 ..n_samples)
84
- . map ( |i| group_ids[ i * n_factors + j] )
85
- . collect ( )
86
- } )
87
- . collect ( ) ;
88
-
89
- // Precompute group weight slices
90
- let group_weight_slices: Vec < & [ f64 ] > = ( 0 ..n_factors)
91
- . map ( |j| & group_weights[ j * n_groups..( j + 1 ) * n_groups] )
92
- . collect ( ) ;
93
-
94
- let process_column = |( k, mut col) : ( usize , ndarray:: ArrayViewMut1 < f64 > ) | {
95
- let mut xk_curr: Vec < f64 > = ( 0 ..n_samples) . map ( |i| x[ [ i, k] ] ) . collect ( ) ;
96
- let mut xk_prev: Vec < f64 > = xk_curr. iter ( ) . map ( |& v| v - 1.0 ) . collect ( ) ;
97
- let mut gw_sums = vec ! [ 0.0 ; n_groups] ;
98
-
99
- let mut converged = false ;
100
- for _ in 0 ..maxiter {
101
- for j in 0 ..n_factors {
102
- internal:: subtract_weighted_group_mean (
103
- & mut xk_curr,
104
- & sample_weights,
105
- & group_ids_by_factor[ j] ,
106
- group_weight_slices[ j] ,
107
- & mut gw_sums,
108
- ) ;
109
- }
110
+ // Instead of recomputing the denominators for the weighted group averages every time,
111
+ // we'll precompute them and store them in a grid-like structure. The grid will have
112
+ // dimensions (m, k) where m is the number of factors and k is the maximum group ID.
113
+ struct FactorGroupWeights {
114
+ values : Vec < f64 > ,
115
+ width : usize ,
116
+ }
117
+
118
+ impl FactorGroupWeights {
119
+ fn new ( flist : & ArrayView2 < usize > , weights : & ArrayView1 < f64 > ) -> Self {
120
+ let n_samples = flist. nrows ( ) ;
121
+ let n_factors = flist. ncols ( ) ;
122
+ let width = flist. iter ( ) . max ( ) . unwrap ( ) + 1 ;
110
123
111
- if internal:: sad_converged ( & xk_curr, & xk_prev, tol) {
112
- converged = true ;
113
- break ;
124
+ let mut values = vec ! [ 0.0 ; n_factors * width] ;
125
+ for i in 0 ..n_samples {
126
+ let weight = weights[ i] ;
127
+ for j in 0 ..n_factors {
128
+ let id = flist[ [ i, j] ] ;
129
+ values[ j * width + id] += weight;
114
130
}
115
- xk_prev. copy_from_slice ( & xk_curr) ;
116
131
}
117
132
118
- if !converged {
119
- not_converged. fetch_add ( 1 , Ordering :: SeqCst ) ;
133
+ Self {
134
+ values,
135
+ width,
120
136
}
121
- Zip :: from ( & mut col) . and ( & xk_curr) . for_each ( |col_elm, & val| {
122
- * col_elm = val;
123
- } ) ;
124
- } ;
125
-
126
- let mut res = Array2 :: < f64 > :: zeros ( ( n_samples, n_features) ) ;
127
-
128
- res. axis_iter_mut ( ndarray:: Axis ( 1 ) )
129
- . into_par_iter ( )
130
- . enumerate ( )
131
- . for_each ( process_column) ;
137
+ }
132
138
133
- let success = not_converged. load ( Ordering :: SeqCst ) == 0 ;
134
- ( res, success)
139
+ fn factor_weight_slice ( & self , factor_index : usize ) -> & [ f64 ] {
140
+ & self . values [ factor_index * self . width ..( factor_index + 1 ) * self . width ]
141
+ }
135
142
}
136
143
137
144
@@ -196,7 +203,6 @@ fn demean_impl(
196
203
/// print(x_demeaned)
197
204
/// print("Converged:", converged)
198
205
/// ```
199
-
200
206
#[ pyfunction]
201
207
#[ pyo3( signature = ( x, flist, weights, tol=1e-8 , maxiter=100_000 ) ) ]
202
208
pub fn _demean_rs (
@@ -207,13 +213,18 @@ pub fn _demean_rs(
207
213
tol : f64 ,
208
214
maxiter : usize ,
209
215
) -> PyResult < ( Py < PyArray2 < f64 > > , bool ) > {
210
- let x_arr = x. as_array ( ) ;
211
- let flist_arr = flist. as_array ( ) ;
212
- let weights_arr = weights. as_array ( ) ;
213
-
214
- let ( out, success) =
215
- py. allow_threads ( || demean_impl ( & x_arr, & flist_arr, & weights_arr, tol, maxiter) ) ;
216
-
217
- let pyarray = PyArray2 :: from_owned_array ( py, out) ;
218
- Ok ( ( pyarray. into_py ( py) , success) )
216
+ let mut x_array = x. as_array ( ) . to_owned ( ) ;
217
+ let flist_array = flist. as_array ( ) ;
218
+ let weights_array = weights. as_array ( ) ;
219
+
220
+ let converged = demean_impl (
221
+ & mut x_array,
222
+ flist_array,
223
+ weights_array,
224
+ tol,
225
+ maxiter,
226
+ ) ;
227
+
228
+ let pyarray = PyArray2 :: from_owned_array ( py, x_array) ;
229
+ Ok ( ( pyarray. into_py ( py) , converged) )
219
230
}
0 commit comments