Skip to content

Commit c638fb0

Browse files
committed
Optimise demean_rs
1 parent 1eb5f18 commit c638fb0

File tree

1 file changed

+132
-121
lines changed

1 file changed

+132
-121
lines changed

src/demean.rs

Lines changed: 132 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -1,137 +1,144 @@
1-
use ndarray::{Array2, ArrayView1, ArrayView2, Zip};
1+
use ndarray::{Array2, ArrayView1, ArrayView2, ArrayViewMut1};
22
use numpy::{PyArray2, PyReadonlyArray1, PyReadonlyArray2};
33
use pyo3::prelude::*;
44
use rayon::prelude::*;
5-
use std::sync::atomic::{AtomicUsize, Ordering};
5+
use std::sync::atomic::{AtomicBool, Ordering};
66
use std::sync::Arc;
77

8-
mod internal {
9-
pub(super) fn sad_converged(a: &[f64], b: &[f64], tol: f64) -> bool {
10-
a.iter().zip(b).all(|(&x, &y)| (x - y).abs() < tol)
8+
9+
#[allow(non_snake_case)]
10+
pub fn demean_impl(X: &mut Array2<f64>, D: ArrayView2<usize>, weights: ArrayView1<f64>, tol: f64, iterations: usize) -> bool {
11+
let nsamples = X.nrows();
12+
let nfactors = D.ncols();
13+
let success = Arc::new(AtomicBool::new(true));
14+
let group_weights = FactorGroupWeights::new(&D, &weights);
15+
16+
X.axis_iter_mut(ndarray::Axis(1))
17+
.into_par_iter()
18+
.for_each(|mut column| {
19+
let mut demeaner = ColumnDemeaner::new(nsamples, group_weights.width);
20+
21+
for _ in 0..iterations {
22+
for i in 0..nfactors {
23+
demeaner.demean_column(
24+
&mut column,
25+
&weights,
26+
&D.column(i),
27+
group_weights.factor_weight_slice(i)
28+
);
29+
}
30+
31+
demeaner.check_convergence(&column.view(), tol);
32+
if demeaner.converged {
33+
break;
34+
}
35+
}
36+
37+
if !demeaner.converged {
38+
// We can use a relaxed ordering since we only ever go from true to false
39+
// and it doesn't matter how many times we do this.
40+
success.store(false, Ordering::Relaxed);
41+
}
42+
});
43+
44+
success.load(Ordering::Relaxed)
45+
}
46+
47+
// The column demeaner is in charge of subtracting group means until convergence.
48+
struct ColumnDemeaner {
49+
converged: bool,
50+
checkpoint: Vec<f64>,
51+
group_sums: Vec<f64>,
52+
}
53+
54+
impl ColumnDemeaner {
55+
fn new(n: usize, k: usize) -> Self {
56+
Self {
57+
converged: false,
58+
checkpoint: vec![0.0; n],
59+
group_sums: vec![0.0; k],
60+
}
1161
}
1262

13-
pub(super) fn subtract_weighted_group_mean(
14-
x: &mut [f64],
15-
sample_weights: &[f64],
16-
group_ids: &[usize],
63+
fn demean_column(
64+
&mut self,
65+
x: &mut ArrayViewMut1<f64>,
66+
weights: &ArrayView1<f64>,
67+
groups: &ArrayView1<usize>,
1768
group_weights: &[f64],
18-
group_weighted_sums: &mut [f64],
1969
) {
20-
group_weighted_sums.fill(0.0);
70+
self.group_sums.fill(0.0);
2171

22-
// Accumulate weighted sums per group
23-
x.iter()
24-
.zip(sample_weights)
25-
.zip(group_ids)
26-
.for_each(|((&xi, &wi), &gid)| {
27-
group_weighted_sums[gid] += wi * xi;
28-
});
72+
// Compute group sums
73+
for ((&xi, &wi), &gid) in x.iter().zip(weights).zip(groups) {
74+
self.group_sums[gid] += wi * xi;
75+
}
2976

30-
// Compute group means
31-
let group_means: Vec<f64> = group_weighted_sums
32-
.iter()
33-
.zip(group_weights)
34-
.map(|(&sum, &weight)| sum / weight)
35-
.collect();
77+
// Convert sums to means
78+
self.group_sums
79+
.iter_mut()
80+
.zip(group_weights.iter())
81+
.for_each(|(sum, &weight)| {
82+
*sum /= weight
83+
});
3684

37-
// Subtract means from each sample
38-
x.iter_mut().zip(group_ids).for_each(|(xi, &gid)| {
39-
*xi -= group_means[gid];
40-
});
85+
// Subtract group means
86+
for (xi, &gid) in x.iter_mut().zip(groups) {
87+
*xi -= self.group_sums[gid] // Really these are means now
88+
}
4189
}
4290

43-
pub(super) fn calc_group_weights(
44-
sample_weights: &[f64],
45-
group_ids: &[usize],
46-
n_samples: usize,
47-
n_factors: usize,
48-
n_groups: usize,
49-
) -> Vec<f64> {
50-
let mut group_weights = vec![0.0; n_factors * n_groups];
51-
for i in 0..n_samples {
52-
let weight = sample_weights[i];
53-
for j in 0..n_factors {
54-
let id = group_ids[i * n_factors + j];
55-
group_weights[j * n_groups + id] += weight;
56-
}
57-
}
58-
group_weights
91+
92+
// Check elementwise convergence and update checkpoint
93+
fn check_convergence(
94+
&mut self,
95+
x: &ArrayView1<f64>,
96+
tol: f64,
97+
) {
98+
self.converged = true; // Innocent until proven guilty
99+
x.iter()
100+
.zip(self.checkpoint.iter_mut())
101+
.for_each(|(&xi, cp)| {
102+
if (xi - *cp).abs() > tol {
103+
self.converged = false; // Guilty!
104+
}
105+
*cp = xi; // Update checkpoint
106+
});
59107
}
60108
}
61109

62-
fn demean_impl(
63-
x: &ArrayView2<f64>,
64-
flist: &ArrayView2<usize>,
65-
weights: &ArrayView1<f64>,
66-
tol: f64,
67-
maxiter: usize,
68-
) -> (Array2<f64>, bool) {
69-
let (n_samples, n_features) = x.dim();
70-
let n_factors = flist.ncols();
71-
let n_groups = flist.iter().cloned().max().unwrap() + 1;
72-
73-
let sample_weights: Vec<f64> = weights.iter().cloned().collect();
74-
let group_ids: Vec<usize> = flist.iter().cloned().collect();
75-
let group_weights =
76-
internal::calc_group_weights(&sample_weights, &group_ids, n_samples, n_factors, n_groups);
77-
78-
let not_converged = Arc::new(AtomicUsize::new(0));
79-
80-
// Precompute slices of group_ids for each factor
81-
let group_ids_by_factor: Vec<Vec<usize>> = (0..n_factors)
82-
.map(|j| {
83-
(0..n_samples)
84-
.map(|i| group_ids[i * n_factors + j])
85-
.collect()
86-
})
87-
.collect();
88-
89-
// Precompute group weight slices
90-
let group_weight_slices: Vec<&[f64]> = (0..n_factors)
91-
.map(|j| &group_weights[j * n_groups..(j + 1) * n_groups])
92-
.collect();
93-
94-
let process_column = |(k, mut col): (usize, ndarray::ArrayViewMut1<f64>)| {
95-
let mut xk_curr: Vec<f64> = (0..n_samples).map(|i| x[[i, k]]).collect();
96-
let mut xk_prev: Vec<f64> = xk_curr.iter().map(|&v| v - 1.0).collect();
97-
let mut gw_sums = vec![0.0; n_groups];
98-
99-
let mut converged = false;
100-
for _ in 0..maxiter {
101-
for j in 0..n_factors {
102-
internal::subtract_weighted_group_mean(
103-
&mut xk_curr,
104-
&sample_weights,
105-
&group_ids_by_factor[j],
106-
group_weight_slices[j],
107-
&mut gw_sums,
108-
);
109-
}
110+
// Instead of recomputing the denominators for the weighted group averages every time,
111+
// we'll precompute them and store them in a grid-like structure. The grid will have
112+
// dimensions (m, k) where m is the number of factors and k is the maximum group ID.
113+
struct FactorGroupWeights {
114+
values: Vec<f64>,
115+
width: usize,
116+
}
117+
118+
impl FactorGroupWeights {
119+
fn new(flist: &ArrayView2<usize>, weights: &ArrayView1<f64>) -> Self {
120+
let n_samples = flist.nrows();
121+
let n_factors = flist.ncols();
122+
let width = flist.iter().max().unwrap() + 1;
110123

111-
if internal::sad_converged(&xk_curr, &xk_prev, tol) {
112-
converged = true;
113-
break;
124+
let mut values = vec![0.0; n_factors * width];
125+
for i in 0..n_samples {
126+
let weight = weights[i];
127+
for j in 0..n_factors {
128+
let id = flist[[i, j]];
129+
values[j * width + id] += weight;
114130
}
115-
xk_prev.copy_from_slice(&xk_curr);
116131
}
117132

118-
if !converged {
119-
not_converged.fetch_add(1, Ordering::SeqCst);
133+
Self {
134+
values,
135+
width,
120136
}
121-
Zip::from(&mut col).and(&xk_curr).for_each(|col_elm, &val| {
122-
*col_elm = val;
123-
});
124-
};
125-
126-
let mut res = Array2::<f64>::zeros((n_samples, n_features));
127-
128-
res.axis_iter_mut(ndarray::Axis(1))
129-
.into_par_iter()
130-
.enumerate()
131-
.for_each(process_column);
137+
}
132138

133-
let success = not_converged.load(Ordering::SeqCst) == 0;
134-
(res, success)
139+
fn factor_weight_slice(&self, factor_index: usize) -> &[f64] {
140+
&self.values[factor_index * self.width..(factor_index + 1) * self.width]
141+
}
135142
}
136143

137144

@@ -196,7 +203,6 @@ fn demean_impl(
196203
/// print(x_demeaned)
197204
/// print("Converged:", converged)
198205
/// ```
199-
200206
#[pyfunction]
201207
#[pyo3(signature = (x, flist, weights, tol=1e-8, maxiter=100_000))]
202208
pub fn _demean_rs(
@@ -207,13 +213,18 @@ pub fn _demean_rs(
207213
tol: f64,
208214
maxiter: usize,
209215
) -> PyResult<(Py<PyArray2<f64>>, bool)> {
210-
let x_arr = x.as_array();
211-
let flist_arr = flist.as_array();
212-
let weights_arr = weights.as_array();
213-
214-
let (out, success) =
215-
py.allow_threads(|| demean_impl(&x_arr, &flist_arr, &weights_arr, tol, maxiter));
216-
217-
let pyarray = PyArray2::from_owned_array(py, out);
218-
Ok((pyarray.into_py(py), success))
216+
let mut x_array = x.as_array().to_owned();
217+
let flist_array = flist.as_array();
218+
let weights_array = weights.as_array();
219+
220+
let converged = demean_impl(
221+
&mut x_array,
222+
flist_array,
223+
weights_array,
224+
tol,
225+
maxiter,
226+
);
227+
228+
let pyarray = PyArray2::from_owned_array(py, x_array);
229+
Ok((pyarray.into_py(py), converged))
219230
}

0 commit comments

Comments
 (0)