Make sampling from Hypergeometric thread-safe #46

Closed
adomasbaliuka opened this issue Jun 5, 2024 · 9 comments

Comments

@adomasbaliuka
Contributor

This code here says the static variables "should become 'thread_local globals'". Is there anything preventing us from simply adding _Thread_local to their declarations?

With the current version, sampling from the hypergeometric distribution (rhyper) is broken when it is called from multiple threads. I mentioned this previously in JuliaStats/Distributions.jl#1829.

@ViralBShah
Contributor

Happy to accept a PR that adds a patch file, which gets applied on top of the upstream R sources (so that we can keep updating to new R versions and then re-applying the patch).

@adomasbaliuka
Contributor Author

I don't understand what you mean by "patch file".

@ViralBShah
Contributor

@adomasbaliuka
Contributor Author

adomasbaliuka commented Aug 11, 2024

I guess that means you don't use pull requests? Fine...

Before changing anything, I tried to reproduce the error explicitly in test.jl by adding this at the bottom:

@testset "rhyper" begin
    # double rhyper(double nn1in, double nn2in, double kkin)
    Nred = 30.0
    Nblue = 40.0
    Npulled = 5.0

    hyper_samples = [
        ccall((:rhyper, libRmath), Float64, (Float64, Float64, Float64), Nred, Nblue, Npulled)
        for _ in 1:1_000_000
    ]
    expected_mean = Npulled * Nred / (Nred + Nblue)
    sample_mean = sum(hyper_samples) / length(hyper_samples)
    @test sample_mean ≈ expected_mean rtol = 0.001

    N = (Nred + Nblue)
    expected_variance = Npulled * Nred * (N - Nred) * (N - Npulled) / (N * N * (N - 1))
    sample_variance = 1 / (length(hyper_samples)) * sum((hyper_samples .- sample_mean) .^ 2)
    @test sample_variance ≈ expected_variance rtol = 0.001
end

@testset "rhyper_multithreaded" begin
    # double rhyper(double nn1in, double nn2in, double kkin)
    Nred = 30.0
    Nblue = 40.0
    Npulled = 5.0

    hyper_samples = Vector{Float64}(undef, 10_000_000)
    Threads.@threads for i in eachindex(hyper_samples)
        hyper_samples[i] = ccall(
            (:rhyper, libRmath), Float64, (Float64, Float64, Float64),
            Nred, Nblue, Npulled
        )
    end

    expected_mean = Npulled * Nred / (Nred + Nblue)
    sample_mean = sum(hyper_samples) / length(hyper_samples)
    @test sample_mean ≈ expected_mean rtol = 0.001

    N = (Nred + Nblue)
    expected_variance = Npulled * Nred * (N - Nred) * (N - Npulled) / (N * N * (N - 1))
    sample_variance = 1 / (length(hyper_samples)) * sum((hyper_samples .- sample_mean) .^ 2)
    @test sample_variance ≈ expected_variance rtol = 0.001
end

To my surprise, this seems to work correctly even when using threads.

I'm not sure why the original issue with Distributions.jl behaves differently, since Distributions.jl seems to call this library the same way I do (defined here and then here).
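One possible explanation, assuming rhyper memoizes its setup in those static variables keyed on the arguments of its previous call (the variables the "thread_local globals" comment refers to): when every call passes the same arguments, whichever thread last refreshed the cache stored exactly the setup every other thread needs, so the race is largely harmless in practice. When the arguments vary from call to call (for example a K that is redrawn each time), one thread can refresh the cache between another thread's setup and its sampling, and that thread then silently samples with the wrong setup. The sketch below is a hypothetical pure-Julia model of that pattern, not Rmath's code; `SetupCache`, `setup!`, and `cached_value` are made-up names, and one possible two-thread interleaving is replayed deterministically.

```julia
# Hypothetical model (not Rmath's code) of a function that memoizes its
# setup in shared state, keyed on the argument of the previous call.
mutable struct SetupCache
    last_arg::Float64   # argument seen by the previous call
    setup::Float64      # "expensive" value derived from that argument
end

# One instance shared by every caller, playing the role of C `static` variables.
const SHARED = SetupCache(NaN, NaN)

# Phase 1 of a call: refresh the cached setup only if the argument changed.
# (NaN != NaN, so the very first call always computes the setup.)
function setup!(cache::SetupCache, arg::Float64)
    if cache.last_arg != arg
        cache.last_arg = arg
        cache.setup = arg / 2       # stand-in for the real setup computation
    end
    return cache
end

# Phase 2 of a call: sample using whatever setup is cached *now*.
cached_value(cache::SetupCache) = cache.setup

# Same argument on both "threads": any interleaving reads the right setup.
setup!(SHARED, 10.0)                # thread A, arg = 10
setup!(SHARED, 10.0)                # thread B, arg = 10
same_args = cached_value(SHARED)    # thread A reads 5.0, as intended

# Different arguments: B refreshes the cache between A's two phases,
# so A silently uses B's setup.
setup!(SHARED, 10.0)                # thread A, arg = 10
setup!(SHARED, 100.0)               # thread B, arg = 100
mixed_args = cached_value(SHARED)   # thread A reads 50.0 instead of 5.0
```

If this model is right, a test with fixed Nred, Nblue and Npulled cannot expose the bug, while a workload whose arguments change between calls can.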

Edit:

An example that more closely reproduces the original issue does show the errors:

using Distributions
function sample_KkC(n; N, Q)
    total_errors = Distributions.Binomial(N, Q)
    K = rand(total_errors)
    k = ccall(
        (:rhyper, libRmath), Float64, (Float64, Float64, Float64),
        K, N-K, n
    )
    return k
end

@testset "full" begin
    function f(Q)
        objective(n) = [sample_KkC(n; N = 819_200, Q) for _ = 1:100]
        vals = [10, 100]
        objective.(vals)
    end

    Qs = [0.05, 0.055, 0.1, 0.2, 0.3]

    Threads.@threads for i in eachindex(Qs)
        f(Qs[i])
    end
end

@adomasbaliuka
Contributor Author

adomasbaliuka commented Aug 11, 2024

And making the static variables _Thread_local seems to fix the issue.

Patch:
patch.txt

Note: the patch also contains my edits to test.jl, which you probably don't want, but they demonstrate the error and show that the patch fixes it.
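Why the _Thread_local change helps can be sketched with a small hypothetical pure-Julia model (`SetupCache`, `setup!`, and `cached_value` are illustrative names, not Rmath code), assuming the statics hold setup that is memoized on the previous call's arguments. C11's _Thread_local gives every thread its own instance of each such variable; here the two per-thread instances are simply explicit objects, so interleaving the two threads' setup and use phases no longer mixes their state.

```julia
# Hypothetical model (not Rmath's code): setup memoized on the previous argument.
mutable struct SetupCache
    last_arg::Float64   # argument seen by this thread's previous call
    setup::Float64      # value derived from that argument
end

# Refresh the cached setup only when the argument changed (NaN != NaN,
# so the first call always computes it).
function setup!(cache::SetupCache, arg::Float64)
    if cache.last_arg != arg
        cache.last_arg = arg
        cache.setup = arg / 2       # stand-in for the real setup computation
    end
    return cache
end

cached_value(cache::SetupCache) = cache.setup

# One cache per thread: the analogue of declaring the statics `_Thread_local`.
cache_a = SetupCache(NaN, NaN)      # owned by thread A
cache_b = SetupCache(NaN, NaN)      # owned by thread B

# B's refresh lands between A's setup and A's use, but touches only B's copy.
setup!(cache_a, 10.0)               # thread A, arg = 10
setup!(cache_b, 100.0)              # thread B, arg = 100
a_value = cached_value(cache_a)     # 5.0: A still sees its own setup
b_value = cached_value(cache_b)     # 50.0
```

A design note on the trade-off: per-thread copies keep the memoization, so a thread that keeps calling with the same arguments still skips the setup. A single global lock around the call would also prevent the race, but it would serialize every caller.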

@adomasbaliuka
Contributor Author

adomasbaliuka commented Aug 11, 2024

I found more places where the static variables are apparently intended "to be made thread-local", and made them all _Thread_local in this new patch:
patch_all_statics_now_threadlocal.txt

I'm still unsure whether you want a PR or just the diff, so I opened a PR as well...

@ViralBShah
Contributor

ViralBShah commented Aug 11, 2024

Right, we want a PR that adds this patch as a file. That way we keep carrying the patch and re-apply it every time we upgrade the Rmath version from the R distribution with make update.

If we merge that PR as-is, its changes will be overwritten the next time we sync a new upstream Rmath version with make update. With a patch file checked in, we apply the patch on every update. Hopefully this code doesn't change often, so the same patch keeps applying across releases.

@adomasbaliuka
Contributor Author

Do I understand correctly that you don't need anything further from me?

@ViralBShah
Copy link
Contributor

Merged and added #51. Thank you!
