Skip to content

Commit

Permalink
Merge pull request #50 from adomasbaliuka/adomas_towards_threadlocal
Browse files Browse the repository at this point in the history
Improve thread-safety
  • Loading branch information
ViralBShah authored Aug 12, 2024
2 parents d560159 + 575c153 commit f707fe5
Show file tree
Hide file tree
Showing 10 changed files with 78 additions and 36 deletions.
6 changes: 3 additions & 3 deletions src/rbeta.c
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ double rbeta(double aa, double bb)
int qsame;
/* FIXME: Keep Globals (properly) for threading */
/* Uses these GLOBALS to save time when many rv's are generated : */
static double beta, gamma, delta, k1, k2;
static double olda = -1.0;
static double oldb = -1.0;
_Thread_local static double beta, gamma, delta, k1, k2;
_Thread_local static double olda = -1.0;
_Thread_local static double oldb = -1.0;

/* Test if we need new "initializing" */
qsame = (olda == aa) && (oldb == bb);
Expand Down
10 changes: 5 additions & 5 deletions src/rbinom.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ double rbinom(double nin, double pp)
{
/* FIXME: These should become THREAD_specific globals : */

static double c, fm, npq, p1, p2, p3, p4, qn;
static double xl, xll, xlr, xm, xr;
_Thread_local static double c, fm, npq, p1, p2, p3, p4, qn;
_Thread_local static double xl, xll, xlr, xm, xr;

static double psave = -1.0;
static int nsave = -1;
static int m;
_Thread_local static double psave = -1.0;
_Thread_local static int nsave = -1;
_Thread_local static int m;

double f, f1, f2, u, v, w, w2, x, x1, x2, z, z2;
double p, q, np, g, r, al, alv, amaxp, ffm, ynorm;
Expand Down
10 changes: 5 additions & 5 deletions src/rgamma.c
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ double rgamma(double a, double scale)
const static double a6 = -0.1367177;
const static double a7 = 0.1233795;

/* State variables [FIXME for threading!] :*/
static double aa = 0.;
static double aaa = 0.;
static double s, s2, d; /* no. 1 (step 1) */
static double q0, b, si, c;/* no. 2 (step 4) */
/* State variables :*/
_Thread_local static double aa = 0.;
_Thread_local static double aaa = 0.;
_Thread_local static double s, s2, d; /* no. 1 (step 1) */
_Thread_local static double q0, b, si, c;/* no. 2 (step 4) */

double e, p, q, r, t, u, v, w, x, ret_val;

Expand Down
13 changes: 6 additions & 7 deletions src/rhyper.c
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,15 @@ double rhyper(double nn1in, double nn2in, double kkin)
int ix; // return value (coerced to double at the very end)
Rboolean setup1, setup2;

/* These should become 'thread_local globals' : */
static int ks = -1, n1s = -1, n2s = -1;
static int m, minjx, maxjx;
static int k, n1, n2; // <- not allowing larger integer par
static double N;
_Thread_local static int ks = -1, n1s = -1, n2s = -1;
_Thread_local static int m, minjx, maxjx;
_Thread_local static int k, n1, n2; // <- not allowing larger integer par
_Thread_local static double N;

// II :
static double w;
_Thread_local static double w;
// III:
static double a, d, s, xl, xr, kl, kr, lamdl, lamdr, p1, p2, p3;
_Thread_local static double a, d, s, xl, xr, kl, kr, lamdl, lamdr, p1, p2, p3;

/* check parameter validity */

Expand Down
10 changes: 5 additions & 5 deletions src/rpois.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,12 @@ double rpois(double mu)
};

/* These are static --- persistent between calls for same mu : */
static int l, m;
_Thread_local static int l, m;

static double b1, b2, c, c0, c1, c2, c3;
static double pp[36], p0, p, q, s, d, omega;
static double big_l;/* integer "w/o overflow" */
static double muprev = 0., muprev2 = 0.;/*, muold = 0.*/
_Thread_local static double b1, b2, c, c0, c1, c2, c3;
_Thread_local static double pp[36], p0, p, q, s, d, omega;
_Thread_local static double big_l;/* integer "w/o overflow" */
_Thread_local static double muprev = 0., muprev2 = 0.;/*, muold = 0.*/

/* Local Vars [initialize some for -Wall]: */
double del, difmuk= 0., E= 0., fk= 0., fx, fy, g, px, py, t, u= 0., v, x;
Expand Down
4 changes: 2 additions & 2 deletions src/signrank.c
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
#include "nmath.h"
#include "dpq.h"

static double *w;
static int allocated_n;
_Thread_local static double *w;
_Thread_local static int allocated_n;

static void
w_free(void)
Expand Down
2 changes: 1 addition & 1 deletion src/sunif.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

/* A version of Marsaglia-MultiCarry */

static unsigned int I1=1234, I2=5678;
_Thread_local static unsigned int I1=1234, I2=5678;

void set_seed(unsigned int i1, unsigned int i2)
{
Expand Down
12 changes: 6 additions & 6 deletions src/toms708.c
Original file line number Diff line number Diff line change
Expand Up @@ -1491,12 +1491,12 @@ double rexpm1(double x)
/* EVALUATION OF THE FUNCTION EXP(X) - 1 */
/* ----------------------------------------------------------------------- */

static double p1 = 9.14041914819518e-10;
static double p2 = .0238082361044469;
static double q1 = -.499999999085958;
static double q2 = .107141568980644;
static double q3 = -.0119041179760821;
static double q4 = 5.95130811860248e-4;
_Thread_local static double p1 = 9.14041914819518e-10;
_Thread_local static double p2 = .0238082361044469;
_Thread_local static double q1 = -.499999999085958;
_Thread_local static double q2 = .107141568980644;
_Thread_local static double q3 = -.0119041179760821;
_Thread_local static double q4 = 5.95130811860248e-4;

if (fabs(x) <= 0.15) {
return x * (((p2 * x + p1) * x + 1.) /
Expand Down
4 changes: 2 additions & 2 deletions src/wilcox.c
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@
#include <R_ext/Utils.h>
#endif

static double ***w; /* to store cwilcox(i,j,k) -> w[i][j][k] */
static int allocated_m, allocated_n;
_Thread_local static double ***w; /* to store cwilcox(i,j,k) -> w[i][j][k] */
_Thread_local static int allocated_m, allocated_n;

static void
w_free(int m, int n)
Expand Down
43 changes: 43 additions & 0 deletions test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,46 @@ unsafe_store!(cglobal((:exp_rand_ptr, libRmath), Ptr{Cvoid}),
@test ccall((:dbeta, libRmath), Float64, (Float64, Float64, Float64, Int32), 0.5, 0.1, 5.0, 0) 0.014267678091051986
@test 0 <= ccall((:rbeta, libRmath), Float64, (Float64, Float64), 0.1, 5.0) <= 1.0
end

@testset "rhyper" begin
# double rhyper(double nn1in, double nn2in, double kkin)
Nred = 30.0
Nblue = 40.0
Npulled = 5.0

hyper_samples = [
ccall((:rhyper, libRmath), Float64, (Float64, Float64, Float64), Nred, Nblue, Npulled)
for _ in 1:1_000_000
]
expected_mean = Npulled * Nred / (Nred + Nblue)
sample_mean = sum(hyper_samples) / length(hyper_samples)
@test sample_mean expected_mean rtol = 0.01

N = (Nred + Nblue)
expected_variance = Npulled * Nred * (N - Nred) * (N - Npulled) / (N * N * (N - 1))
sample_variance = 1 / (length(hyper_samples)) * sum((hyper_samples .- sample_mean) .^ 2)
@test sample_variance expected_variance rtol = 0.01
end

function sample_KkC(n; N, Q)
K = rand([1,2,3,4,5])
k = ccall(
(:rhyper, libRmath), Float64, (Float64, Float64, Float64),
K, N-K, n
)
return k
end

@testset "fulll" begin
function f(Q)
objective(n) = [sample_KkC(n; N = 819_200, Q) for _ = 1:100]
vals = [10, 100]
objective.(vals)
end

Qs = [0.05, 0.055, 0.1, 0.2, 0.3]

Threads.@threads for i in eachindex(Qs)
f(Qs[i])
end
end

0 comments on commit f707fe5

Please sign in to comment.