Skip to content

Optimise demean_rs #928

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 132 additions & 121 deletions src/demean.rs
Original file line number Diff line number Diff line change
@@ -1,137 +1,144 @@
use ndarray::{Array2, ArrayView1, ArrayView2, Zip};
use ndarray::{Array2, ArrayView1, ArrayView2, ArrayViewMut1};
use numpy::{PyArray2, PyReadonlyArray1, PyReadonlyArray2};
use pyo3::prelude::*;
use rayon::prelude::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

mod internal {
pub(super) fn sad_converged(a: &[f64], b: &[f64], tol: f64) -> bool {
a.iter().zip(b).all(|(&x, &y)| (x - y).abs() < tol)

/// Demeans every column of `X` in place by repeatedly subtracting weighted group means.
///
/// Columns are processed in parallel. For each column, one sweep subtracts the weighted
/// group mean of every factor in turn; sweeps repeat until the column stops changing by
/// more than `tol` between two sweeps, or until `iterations` sweeps have run.
///
/// * `X` - data matrix, modified in place; one row per sample.
/// * `D` - group ids; row `i`, column `j` holds the group of sample `i` in factor `j`.
/// * `weights` - per-sample weights.
/// * `tol` - elementwise tolerance used by the convergence check.
/// * `iterations` - maximum number of sweeps per column.
///
/// Returns `true` only if every column converged.
#[allow(non_snake_case)] // `X` and `D` are deliberately upper-case parameter names.
pub fn demean_impl(
    X: &mut Array2<f64>,
    D: ArrayView2<usize>,
    weights: ArrayView1<f64>,
    tol: f64,
    iterations: usize,
) -> bool {
    let nsamples = X.nrows();
    let nfactors = D.ncols();
    // The parallel closure only borrows this flag, so a plain atomic is enough;
    // no reference counting is needed.
    let success = AtomicBool::new(true);
    // Per-factor, per-group weight totals are computed once and shared by all columns.
    let group_weights = FactorGroupWeights::new(&D, &weights);

    X.axis_iter_mut(ndarray::Axis(1))
        .into_par_iter()
        .for_each(|mut column| {
            // Each column gets its own scratch buffers, reused across sweeps.
            let mut demeaner = ColumnDemeaner::new(nsamples, group_weights.width);

            for _ in 0..iterations {
                for i in 0..nfactors {
                    demeaner.demean_column(
                        &mut column,
                        &weights,
                        &D.column(i),
                        group_weights.factor_weight_slice(i),
                    );
                }

                demeaner.check_convergence(&column.view(), tol);
                if demeaner.converged {
                    break;
                }
            }

            if !demeaner.converged {
                // We can use a relaxed ordering since we only ever go from true to false
                // and it doesn't matter how many times we do this.
                success.store(false, Ordering::Relaxed);
            }
        });

    success.load(Ordering::Relaxed)
}

// The column demeaner is in charge of subtracting group means until convergence.
// It owns the per-column scratch buffers so they are allocated once in `new`
// and then reused by every `demean_column` / `check_convergence` call.
struct ColumnDemeaner {
// Outcome of the most recent `check_convergence` call; starts out `false`.
converged: bool,
// Column values recorded by the previous `check_convergence` call (all zeros initially).
checkpoint: Vec<f64>,
// Scratch buffer indexed by group id: holds weighted group sums, then group means,
// while `demean_column` runs.
group_sums: Vec<f64>,
}

impl ColumnDemeaner {
// Creates a demeaner that has not converged yet, with a zeroed checkpoint of
// length `n` and a zeroed group buffer of length `k`.
fn new(n: usize, k: usize) -> Self {
    let checkpoint = vec![0.0; n];
    let group_sums = vec![0.0; k];

    Self {
        converged: false,
        checkpoint,
        group_sums,
    }
}

pub(super) fn subtract_weighted_group_mean(
x: &mut [f64],
sample_weights: &[f64],
group_ids: &[usize],
fn demean_column(
&mut self,
x: &mut ArrayViewMut1<f64>,
weights: &ArrayView1<f64>,
groups: &ArrayView1<usize>,
group_weights: &[f64],
group_weighted_sums: &mut [f64],
) {
group_weighted_sums.fill(0.0);
self.group_sums.fill(0.0);

// Accumulate weighted sums per group
x.iter()
.zip(sample_weights)
.zip(group_ids)
.for_each(|((&xi, &wi), &gid)| {
group_weighted_sums[gid] += wi * xi;
});
// Compute group sums
for ((&xi, &wi), &gid) in x.iter().zip(weights).zip(groups) {
self.group_sums[gid] += wi * xi;
}

// Compute group means
let group_means: Vec<f64> = group_weighted_sums
.iter()
.zip(group_weights)
.map(|(&sum, &weight)| sum / weight)
.collect();
// Convert sums to means
self.group_sums
.iter_mut()
.zip(group_weights.iter())
.for_each(|(sum, &weight)| {
*sum /= weight
});

// Subtract means from each sample
x.iter_mut().zip(group_ids).for_each(|(xi, &gid)| {
*xi -= group_means[gid];
});
// Subtract group means
for (xi, &gid) in x.iter_mut().zip(groups) {
*xi -= self.group_sums[gid] // Really these are means now
}
}

pub(super) fn calc_group_weights(
sample_weights: &[f64],
group_ids: &[usize],
n_samples: usize,
n_factors: usize,
n_groups: usize,
) -> Vec<f64> {
let mut group_weights = vec![0.0; n_factors * n_groups];
for i in 0..n_samples {
let weight = sample_weights[i];
for j in 0..n_factors {
let id = group_ids[i * n_factors + j];
group_weights[j * n_groups + id] += weight;
}
}
group_weights

// Compares `x` elementwise against the stored checkpoint, records in
// `self.converged` whether every entry moved by at most `tol`, and then
// overwrites the checkpoint with the current values of `x`.
//
// NOTE(review): a NaN difference compares false against `tol`, so NaN entries
// do not prevent a column from being reported as converged -- confirm this is intended.
fn check_convergence(
    &mut self,
    x: &ArrayView1<f64>,
    tol: f64,
) {
    let mut all_within_tol = true;

    for (&current, previous) in x.iter().zip(self.checkpoint.iter_mut()) {
        let moved = (current - *previous).abs() > tol;
        if moved {
            all_within_tol = false;
        }
        // The checkpoint is always refreshed, even once a violation has been seen.
        *previous = current;
    }

    self.converged = all_within_tol;
}
}

fn demean_impl(
x: &ArrayView2<f64>,
flist: &ArrayView2<usize>,
weights: &ArrayView1<f64>,
tol: f64,
maxiter: usize,
) -> (Array2<f64>, bool) {
let (n_samples, n_features) = x.dim();
let n_factors = flist.ncols();
let n_groups = flist.iter().cloned().max().unwrap() + 1;

let sample_weights: Vec<f64> = weights.iter().cloned().collect();
let group_ids: Vec<usize> = flist.iter().cloned().collect();
let group_weights =
internal::calc_group_weights(&sample_weights, &group_ids, n_samples, n_factors, n_groups);

let not_converged = Arc::new(AtomicUsize::new(0));

// Precompute slices of group_ids for each factor
let group_ids_by_factor: Vec<Vec<usize>> = (0..n_factors)
.map(|j| {
(0..n_samples)
.map(|i| group_ids[i * n_factors + j])
.collect()
})
.collect();

// Precompute group weight slices
let group_weight_slices: Vec<&[f64]> = (0..n_factors)
.map(|j| &group_weights[j * n_groups..(j + 1) * n_groups])
.collect();

let process_column = |(k, mut col): (usize, ndarray::ArrayViewMut1<f64>)| {
let mut xk_curr: Vec<f64> = (0..n_samples).map(|i| x[[i, k]]).collect();
let mut xk_prev: Vec<f64> = xk_curr.iter().map(|&v| v - 1.0).collect();
let mut gw_sums = vec![0.0; n_groups];

let mut converged = false;
for _ in 0..maxiter {
for j in 0..n_factors {
internal::subtract_weighted_group_mean(
&mut xk_curr,
&sample_weights,
&group_ids_by_factor[j],
group_weight_slices[j],
&mut gw_sums,
);
}
// Instead of recomputing the denominators for the weighted group averages every time,
// we'll precompute them and store them in a grid-like structure. The grid will have
// dimensions (m, k) where m is the number of factors and k is the maximum group ID
// plus one, so that every group ID is a valid column index.
struct FactorGroupWeights {
// Flattened grid, one row per factor: the total weight of group `id` in
// factor `j` lives at `values[j * width + id]`.
values: Vec<f64>,
// Number of entries per factor row (k above).
width: usize,
}

impl FactorGroupWeights {
fn new(flist: &ArrayView2<usize>, weights: &ArrayView1<f64>) -> Self {
let n_samples = flist.nrows();
let n_factors = flist.ncols();
let width = flist.iter().max().unwrap() + 1;

if internal::sad_converged(&xk_curr, &xk_prev, tol) {
converged = true;
break;
let mut values = vec![0.0; n_factors * width];
for i in 0..n_samples {
let weight = weights[i];
for j in 0..n_factors {
let id = flist[[i, j]];
values[j * width + id] += weight;
}
xk_prev.copy_from_slice(&xk_curr);
}

if !converged {
not_converged.fetch_add(1, Ordering::SeqCst);
Self {
values,
width,
}
Zip::from(&mut col).and(&xk_curr).for_each(|col_elm, &val| {
*col_elm = val;
});
};

let mut res = Array2::<f64>::zeros((n_samples, n_features));

res.axis_iter_mut(ndarray::Axis(1))
.into_par_iter()
.enumerate()
.for_each(process_column);
}

let success = not_converged.load(Ordering::SeqCst) == 0;
(res, success)
// Returns the row of precomputed group weights that belongs to factor
// `factor_index`, i.e. `width` consecutive entries of the flattened grid.
fn factor_weight_slice(&self, factor_index: usize) -> &[f64] {
    let row_start = factor_index * self.width;
    let row_end = (factor_index + 1) * self.width;
    &self.values[row_start..row_end]
}
}


Expand Down Expand Up @@ -196,7 +203,6 @@ fn demean_impl(
/// print(x_demeaned)
/// print("Converged:", converged)
/// ```

#[pyfunction]
#[pyo3(signature = (x, flist, weights, tol=1e-8, maxiter=100_000))]
pub fn _demean_rs(
Expand All @@ -207,13 +213,18 @@ pub fn _demean_rs(
tol: f64,
maxiter: usize,
) -> PyResult<(Py<PyArray2<f64>>, bool)> {
let x_arr = x.as_array();
let flist_arr = flist.as_array();
let weights_arr = weights.as_array();

let (out, success) =
py.allow_threads(|| demean_impl(&x_arr, &flist_arr, &weights_arr, tol, maxiter));

let pyarray = PyArray2::from_owned_array(py, out);
Ok((pyarray.into_py(py), success))
let mut x_array = x.as_array().to_owned();
let flist_array = flist.as_array();
let weights_array = weights.as_array();

let converged = demean_impl(
&mut x_array,
flist_array,
weights_array,
tol,
maxiter,
);

let pyarray = PyArray2::from_owned_array(py, x_array);
Ok((pyarray.into_py(py), converged))
}
Loading