Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve thread-safety #50

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/rbeta.c
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ double rbeta(double aa, double bb)
int qsame;
/* FIXME: Keep Globals (properly) for threading */
/* Uses these GLOBALS to save time when many rv's are generated : */
static double beta, gamma, delta, k1, k2;
static double olda = -1.0;
static double oldb = -1.0;
_Thread_local static double beta, gamma, delta, k1, k2;
_Thread_local static double olda = -1.0;
_Thread_local static double oldb = -1.0;

/* Test if we need new "initializing" */
qsame = (olda == aa) && (oldb == bb);
Expand Down
10 changes: 5 additions & 5 deletions src/rbinom.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ double rbinom(double nin, double pp)
{
/* FIXME: These should become THREAD_specific globals : */

static double c, fm, npq, p1, p2, p3, p4, qn;
static double xl, xll, xlr, xm, xr;
_Thread_local static double c, fm, npq, p1, p2, p3, p4, qn;
_Thread_local static double xl, xll, xlr, xm, xr;

static double psave = -1.0;
static int nsave = -1;
static int m;
_Thread_local static double psave = -1.0;
_Thread_local static int nsave = -1;
_Thread_local static int m;

double f, f1, f2, u, v, w, w2, x, x1, x2, z, z2;
double p, q, np, g, r, al, alv, amaxp, ffm, ynorm;
Expand Down
10 changes: 5 additions & 5 deletions src/rgamma.c
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ double rgamma(double a, double scale)
const static double a6 = -0.1367177;
const static double a7 = 0.1233795;

/* State variables [FIXME for threading!] :*/
static double aa = 0.;
static double aaa = 0.;
static double s, s2, d; /* no. 1 (step 1) */
static double q0, b, si, c;/* no. 2 (step 4) */
/* State variables :*/
_Thread_local static double aa = 0.;
_Thread_local static double aaa = 0.;
_Thread_local static double s, s2, d; /* no. 1 (step 1) */
_Thread_local static double q0, b, si, c;/* no. 2 (step 4) */

double e, p, q, r, t, u, v, w, x, ret_val;

Expand Down
13 changes: 6 additions & 7 deletions src/rhyper.c
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,15 @@ double rhyper(double nn1in, double nn2in, double kkin)
int ix; // return value (coerced to double at the very end)
Rboolean setup1, setup2;

/* These should become 'thread_local globals' : */
static int ks = -1, n1s = -1, n2s = -1;
static int m, minjx, maxjx;
static int k, n1, n2; // <- not allowing larger integer par
static double N;
_Thread_local static int ks = -1, n1s = -1, n2s = -1;
_Thread_local static int m, minjx, maxjx;
_Thread_local static int k, n1, n2; // <- not allowing larger integer par
_Thread_local static double N;

// II :
static double w;
_Thread_local static double w;
// III:
static double a, d, s, xl, xr, kl, kr, lamdl, lamdr, p1, p2, p3;
_Thread_local static double a, d, s, xl, xr, kl, kr, lamdl, lamdr, p1, p2, p3;

/* check parameter validity */

Expand Down
10 changes: 5 additions & 5 deletions src/rpois.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,12 @@ double rpois(double mu)
};

/* These are static --- persistent between calls for same mu : */
static int l, m;
_Thread_local static int l, m;

static double b1, b2, c, c0, c1, c2, c3;
static double pp[36], p0, p, q, s, d, omega;
static double big_l;/* integer "w/o overflow" */
static double muprev = 0., muprev2 = 0.;/*, muold = 0.*/
_Thread_local static double b1, b2, c, c0, c1, c2, c3;
_Thread_local static double pp[36], p0, p, q, s, d, omega;
_Thread_local static double big_l;/* integer "w/o overflow" */
_Thread_local static double muprev = 0., muprev2 = 0.;/*, muold = 0.*/

/* Local Vars [initialize some for -Wall]: */
double del, difmuk= 0., E= 0., fk= 0., fx, fy, g, px, py, t, u= 0., v, x;
Expand Down
4 changes: 2 additions & 2 deletions src/signrank.c
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
#include "nmath.h"
#include "dpq.h"

static double *w;
static int allocated_n;
_Thread_local static double *w;
_Thread_local static int allocated_n;

static void
w_free(void)
Expand Down
2 changes: 1 addition & 1 deletion src/sunif.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

/* A version of Marsaglia-MultiCarry */

static unsigned int I1=1234, I2=5678;
_Thread_local static unsigned int I1=1234, I2=5678;

void set_seed(unsigned int i1, unsigned int i2)
{
Expand Down
12 changes: 6 additions & 6 deletions src/toms708.c
Original file line number Diff line number Diff line change
Expand Up @@ -1491,12 +1491,12 @@ double rexpm1(double x)
/* EVALUATION OF THE FUNCTION EXP(X) - 1 */
/* ----------------------------------------------------------------------- */

static double p1 = 9.14041914819518e-10;
static double p2 = .0238082361044469;
static double q1 = -.499999999085958;
static double q2 = .107141568980644;
static double q3 = -.0119041179760821;
static double q4 = 5.95130811860248e-4;
_Thread_local static double p1 = 9.14041914819518e-10;
_Thread_local static double p2 = .0238082361044469;
_Thread_local static double q1 = -.499999999085958;
_Thread_local static double q2 = .107141568980644;
_Thread_local static double q3 = -.0119041179760821;
_Thread_local static double q4 = 5.95130811860248e-4;

if (fabs(x) <= 0.15) {
return x * (((p2 * x + p1) * x + 1.) /
Expand Down
4 changes: 2 additions & 2 deletions src/wilcox.c
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@
#include <R_ext/Utils.h>
#endif

static double ***w; /* to store cwilcox(i,j,k) -> w[i][j][k] */
static int allocated_m, allocated_n;
_Thread_local static double ***w; /* to store cwilcox(i,j,k) -> w[i][j][k] */
_Thread_local static int allocated_m, allocated_n;

static void
w_free(int m, int n)
Expand Down
43 changes: 43 additions & 0 deletions test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,46 @@ unsafe_store!(cglobal((:exp_rand_ptr, libRmath), Ptr{Cvoid}),
@test ccall((:dbeta, libRmath), Float64, (Float64, Float64, Float64, Int32), 0.5, 0.1, 5.0, 0) 0.014267678091051986
@test 0 <= ccall((:rbeta, libRmath), Float64, (Float64, Float64), 0.1, 5.0) <= 1.0
end

@testset "rhyper" begin
# double rhyper(double nn1in, double nn2in, double kkin)
Nred = 30.0
Nblue = 40.0
Npulled = 5.0

hyper_samples = [
ccall((:rhyper, libRmath), Float64, (Float64, Float64, Float64), Nred, Nblue, Npulled)
for _ in 1:1_000_000
]
expected_mean = Npulled * Nred / (Nred + Nblue)
sample_mean = sum(hyper_samples) / length(hyper_samples)
@test sample_mean expected_mean rtol = 0.01

N = (Nred + Nblue)
expected_variance = Npulled * Nred * (N - Nred) * (N - Npulled) / (N * N * (N - 1))
sample_variance = 1 / (length(hyper_samples)) * sum((hyper_samples .- sample_mean) .^ 2)
@test sample_variance expected_variance rtol = 0.01
end

function sample_KkC(n; N, Q)
K = rand([1,2,3,4,5])
k = ccall(
(:rhyper, libRmath), Float64, (Float64, Float64, Float64),
K, N-K, n
)
return k
end

@testset "fulll" begin
function f(Q)
objective(n) = [sample_KkC(n; N = 819_200, Q) for _ = 1:100]
vals = [10, 100]
objective.(vals)
end

Qs = [0.05, 0.055, 0.1, 0.2, 0.3]

Threads.@threads for i in eachindex(Qs)
f(Qs[i])
end
end