From c638fb061b8097c7b43e6a9fbb0fdfe6baea34d4 Mon Sep 17 00:00:00 2001
From: kinto-b
Date: Thu, 5 Jun 2025 22:20:17 +0200
Subject: [PATCH] Optimise demean_rs

---
 src/demean.rs | 253 ++++++++++++++++++++++++++------------------
 1 file changed, 132 insertions(+), 121 deletions(-)

diff --git a/src/demean.rs b/src/demean.rs
index b40f8a64c..6b4a326fd 100644
--- a/src/demean.rs
+++ b/src/demean.rs
@@ -1,137 +1,144 @@
-use ndarray::{Array2, ArrayView1, ArrayView2, Zip};
+use ndarray::{Array2, ArrayView1, ArrayView2, ArrayViewMut1};
 use numpy::{PyArray2, PyReadonlyArray1, PyReadonlyArray2};
 use pyo3::prelude::*;
 use rayon::prelude::*;
-use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
-mod internal {
-    pub(super) fn sad_converged(a: &[f64], b: &[f64], tol: f64) -> bool {
-        a.iter().zip(b).all(|(&x, &y)| (x - y).abs() < tol)
+
+#[allow(non_snake_case)]
+pub fn demean_impl(X: &mut Array2<f64>, D: ArrayView2<usize>, weights: ArrayView1<f64>, tol: f64, iterations: usize) -> bool {
+    let nsamples = X.nrows();
+    let nfactors = D.ncols();
+    let success = Arc::new(AtomicBool::new(true));
+    let group_weights = FactorGroupWeights::new(&D, &weights);
+
+    X.axis_iter_mut(ndarray::Axis(1))
+        .into_par_iter()
+        .for_each(|mut column| {
+            let mut demeaner = ColumnDemeaner::new(nsamples, group_weights.width);
+
+            for _ in 0..iterations {
+                for i in 0..nfactors {
+                    demeaner.demean_column(
+                        &mut column,
+                        &weights,
+                        &D.column(i),
+                        group_weights.factor_weight_slice(i)
+                    );
+                }
+
+                demeaner.check_convergence(&column.view(), tol);
+                if demeaner.converged {
+                    break;
+                }
+            }
+
+            if !demeaner.converged {
+                // We can use a relaxed ordering since we only ever go from true to false
+                // and it doesn't matter how many times we do this.
+                success.store(false, Ordering::Relaxed);
+            }
+        });
+
+    success.load(Ordering::Relaxed)
+}
+
+// The column demeaner is in charge of subtracting group means until convergence.
+struct ColumnDemeaner {
+    converged: bool,
+    checkpoint: Vec<f64>,
+    group_sums: Vec<f64>,
+}
+
+impl ColumnDemeaner {
+    fn new(n: usize, k: usize) -> Self {
+        Self {
+            converged: false,
+            checkpoint: vec![0.0; n],
+            group_sums: vec![0.0; k],
+        }
     }
 
-    pub(super) fn subtract_weighted_group_mean(
-        x: &mut [f64],
-        sample_weights: &[f64],
-        group_ids: &[usize],
+    fn demean_column(
+        &mut self,
+        x: &mut ArrayViewMut1<f64>,
+        weights: &ArrayView1<f64>,
+        groups: &ArrayView1<usize>,
         group_weights: &[f64],
-        group_weighted_sums: &mut [f64],
     ) {
-        group_weighted_sums.fill(0.0);
+        self.group_sums.fill(0.0);
 
-        // Accumulate weighted sums per group
-        x.iter()
-            .zip(sample_weights)
-            .zip(group_ids)
-            .for_each(|((&xi, &wi), &gid)| {
-                group_weighted_sums[gid] += wi * xi;
-            });
+        // Compute group sums
+        for ((&xi, &wi), &gid) in x.iter().zip(weights).zip(groups) {
+            self.group_sums[gid] += wi * xi;
+        }
 
-        // Compute group means
-        let group_means: Vec<f64> = group_weighted_sums
-            .iter()
-            .zip(group_weights)
-            .map(|(&sum, &weight)| sum / weight)
-            .collect();
+        // Convert sums to means
+        self.group_sums
+            .iter_mut()
+            .zip(group_weights.iter())
+            .for_each(|(sum, &weight)| {
+                *sum /= weight
+            });
 
-        // Subtract means from each sample
-        x.iter_mut().zip(group_ids).for_each(|(xi, &gid)| {
-            *xi -= group_means[gid];
-        });
+        // Subtract group means
+        for (xi, &gid) in x.iter_mut().zip(groups) {
+            *xi -= self.group_sums[gid]; // Really these are means now
+        }
     }
 
-    pub(super) fn calc_group_weights(
-        sample_weights: &[f64],
-        group_ids: &[usize],
-        n_samples: usize,
-        n_factors: usize,
-        n_groups: usize,
-    ) -> Vec<f64> {
-        let mut group_weights = vec![0.0; n_factors * n_groups];
-        for i in 0..n_samples {
-            let weight = sample_weights[i];
-            for j in 0..n_factors {
-                let id = group_ids[i * n_factors + j];
-                group_weights[j * n_groups + id] += weight;
-            }
-        }
-        group_weights
+
+    // Check elementwise convergence and update checkpoint
+    fn check_convergence(
+        &mut self,
+        x: &ArrayView1<f64>,
+        tol: f64,
+    ) {
+        self.converged = true; // Innocent until proven guilty
+        x.iter()
+            .zip(self.checkpoint.iter_mut())
+            .for_each(|(&xi, cp)| {
+                if (xi - *cp).abs() > tol {
+                    self.converged = false; // Guilty!
+                }
+                *cp = xi; // Update checkpoint
+            });
     }
 }
 
-fn demean_impl(
-    x: &ArrayView2<f64>,
-    flist: &ArrayView2<usize>,
-    weights: &ArrayView1<f64>,
-    tol: f64,
-    maxiter: usize,
-) -> (Array2<f64>, bool) {
-    let (n_samples, n_features) = x.dim();
-    let n_factors = flist.ncols();
-    let n_groups = flist.iter().cloned().max().unwrap() + 1;
-
-    let sample_weights: Vec<f64> = weights.iter().cloned().collect();
-    let group_ids: Vec<usize> = flist.iter().cloned().collect();
-    let group_weights =
-        internal::calc_group_weights(&sample_weights, &group_ids, n_samples, n_factors, n_groups);
-
-    let not_converged = Arc::new(AtomicUsize::new(0));
-
-    // Precompute slices of group_ids for each factor
-    let group_ids_by_factor: Vec<Vec<usize>> = (0..n_factors)
-        .map(|j| {
-            (0..n_samples)
-                .map(|i| group_ids[i * n_factors + j])
-                .collect()
-        })
-        .collect();
-
-    // Precompute group weight slices
-    let group_weight_slices: Vec<&[f64]> = (0..n_factors)
-        .map(|j| &group_weights[j * n_groups..(j + 1) * n_groups])
-        .collect();
-
-    let process_column = |(k, mut col): (usize, ndarray::ArrayViewMut1<f64>)| {
-        let mut xk_curr: Vec<f64> = (0..n_samples).map(|i| x[[i, k]]).collect();
-        let mut xk_prev: Vec<f64> = xk_curr.iter().map(|&v| v - 1.0).collect();
-        let mut gw_sums = vec![0.0; n_groups];
-
-        let mut converged = false;
-        for _ in 0..maxiter {
-            for j in 0..n_factors {
-                internal::subtract_weighted_group_mean(
-                    &mut xk_curr,
-                    &sample_weights,
-                    &group_ids_by_factor[j],
-                    group_weight_slices[j],
-                    &mut gw_sums,
-                );
-            }
+// Instead of recomputing the denominators for the weighted group averages every time,
+// we'll precompute them and store them in a grid-like structure. The grid will have
+// dimensions (m, k), where m is the number of factors and k is the maximum group ID plus one.
+struct FactorGroupWeights {
+    values: Vec<f64>,
+    width: usize,
+}
+
+impl FactorGroupWeights {
+    fn new(flist: &ArrayView2<usize>, weights: &ArrayView1<f64>) -> Self {
+        let n_samples = flist.nrows();
+        let n_factors = flist.ncols();
+        let width = flist.iter().max().unwrap() + 1;
 
-            if internal::sad_converged(&xk_curr, &xk_prev, tol) {
-                converged = true;
-                break;
+        let mut values = vec![0.0; n_factors * width];
+        for i in 0..n_samples {
+            let weight = weights[i];
+            for j in 0..n_factors {
+                let id = flist[[i, j]];
+                values[j * width + id] += weight;
             }
-            xk_prev.copy_from_slice(&xk_curr);
         }
-        if !converged {
-            not_converged.fetch_add(1, Ordering::SeqCst);
+        Self {
+            values,
+            width,
         }
-        Zip::from(&mut col).and(&xk_curr).for_each(|col_elm, &val| {
-            *col_elm = val;
-        });
-    };
-
-    let mut res = Array2::<f64>::zeros((n_samples, n_features));
-
-    res.axis_iter_mut(ndarray::Axis(1))
-        .into_par_iter()
-        .enumerate()
-        .for_each(process_column);
+    }
 
-    let success = not_converged.load(Ordering::SeqCst) == 0;
-    (res, success)
+    fn factor_weight_slice(&self, factor_index: usize) -> &[f64] {
+        &self.values[factor_index * self.width..(factor_index + 1) * self.width]
+    }
 }
@@ -196,7 +203,6 @@ fn demean_impl(
 /// print(x_demeaned)
 /// print("Converged:", converged)
 /// ```
-
 #[pyfunction]
 #[pyo3(signature = (x, flist, weights, tol=1e-8, maxiter=100_000))]
 pub fn _demean_rs(
@@ -207,13 +213,18 @@ pub fn _demean_rs(
     tol: f64,
     maxiter: usize,
 ) -> PyResult<(Py<PyArray2<f64>>, bool)> {
-    let x_arr = x.as_array();
-    let flist_arr = flist.as_array();
-    let weights_arr = weights.as_array();
-
-    let (out, success) =
-        py.allow_threads(|| demean_impl(&x_arr, &flist_arr, &weights_arr, tol, maxiter));
-
-    let pyarray = PyArray2::from_owned_array(py, out);
-    Ok((pyarray.into_py(py), success))
+    let mut x_array = x.as_array().to_owned();
+    let flist_array = flist.as_array();
+    let weights_array = weights.as_array();
+
+    let converged = demean_impl(
+        &mut x_array,
+        flist_array,
+        weights_array,
+        tol,
+        maxiter,
+    );
+
+    let pyarray = PyArray2::from_owned_array(py, x_array);
+    Ok((pyarray.into_py(py), converged))
 }
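
Note (not part of the patch): a minimal sketch of how the reworked demean_impl could be exercised in a unit test, assuming the generic parameters restored above (Array2<f64> data, ArrayView2<usize> factor IDs, ArrayView1<f64> weights). The module name, input values, and tolerances below are illustrative only.

// Hypothetical test module, illustrative sketch only.
#[cfg(test)]
mod demean_sketch {
    use super::demean_impl;
    use ndarray::{array, Array1, Array2};

    #[test]
    fn demeans_columns_with_two_factors() {
        // 4 samples, 2 columns to demean, 2 fixed-effect factors with 2 groups each.
        let mut x: Array2<f64> = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let d: Array2<usize> = array![[0, 0], [0, 1], [1, 0], [1, 1]];
        let w: Array1<f64> = Array1::ones(4);

        let converged = demean_impl(&mut x, d.view(), w.view(), 1e-8, 100_000);
        assert!(converged);

        // At convergence every group mean is ~0, so each column should sum to ~0.
        for col in x.columns() {
            assert!(col.sum().abs() < 1e-6);
        }
    }
}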