Skip to content

Commit

Permalink
Adapt transformed MCMC code to using BATContext
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz committed Jul 6, 2023
1 parent 756827f commit 9cb2625
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 66 deletions.
22 changes: 10 additions & 12 deletions examples/dev-internal/transformed_example.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,18 @@ using BAT.LinearAlgebra
using BAT.Distributions
using BAT.InverseFunctions
import BAT: TransformedMCMCIterator, TransformedAdaptiveMHTuning, TransformedRAMTuner, TransformedMHProposal, TransformedNoTransformedMCMCTempering, transformed_mcmc_step!!, TransformedMCMCTransformedSampleID
using BAT.Random123
using Random123
using AutoDiffOperators

import BAT: mcmc_iterate!, transformed_mcmc_iterate!, TransformedMCMCSampling

#ENV["JULIA_DEBUG"] = "BAT"

rng = Philox4x()
context = BATContext(ad = ADModule(:ForwardDiff))

posterior = BAT.example_posterior()

my_result = @time BAT.bat_sample_impl(rng, posterior, TransformedMCMCSampling(pre_transform=PriorToGaussian(), nchains=4, nsteps=4*100000))


my_result = @time BAT.bat_sample_impl(posterior, TransformedMCMCSampling(pre_transform=PriorToGaussian(), nchains=4, nsteps=4*100000), context)


density_notrafo = convert(BAT.AbstractMeasureOrDensity, posterior)
Expand All @@ -27,7 +26,7 @@ density, trafo = BAT.transform_and_unshape(PriorToGaussian(), density_notrafo)
c = BAT._approx_cov(density)
f = BAT.CustomTransform(Mul(c))

my_result = @time BAT.bat_sample_impl(rng, posterior, TransformedMCMCSampling(pre_transform=PriorToGaussian(), tuning_alg=TransformedAdaptiveMHTuning(), nchains=4, nsteps=4*100000, adaptive_transform=f))
my_result = @time BAT.bat_sample_impl(posterior, TransformedMCMCSampling(pre_transform=PriorToGaussian(), tuning_alg=TransformedAdaptiveMHTuning(), nchains=4, nsteps=4*100000, adaptive_transform=f), context)

my_samples = my_result.result

Expand All @@ -36,9 +35,9 @@ my_samples = my_result.result
using Plots
plot(my_samples)

r_mh = @time BAT.bat_sample_impl(rng, posterior, MCMCSampling( nchains=4, nsteps=4*100000, store_burnin=true) )
r_mh = @time BAT.bat_sample_impl(posterior, MCMCSampling( nchains=4, nsteps=4*100000, store_burnin=true), context)

r_hmc = @time BAT.bat_sample_impl(rng, posterior, MCMCSampling(mcalg=HamiltonianMC(), nchains=4, nsteps=4*20000) )
r_hmc = @time BAT.bat_sample_impl(posterior, MCMCSampling(mcalg=HamiltonianMC(), nchains=4, nsteps=4*20000), context)

plot(bat_sample(posterior).result)

Expand All @@ -60,9 +59,8 @@ posterior.likelihood.density._log_f(rand(prior2))
posterior2 = PosteriorDensity(BAT.logfuncdensity(posterior.likelihood.density._log_f), prior2)


@profview r_ram2 = @time BAT.bat_sample_impl(rng, posterior2, TransformedMCMCSampling(pre_transform=PriorToGaussian(), nchains=4, nsteps=4*100000))

@profview r_mh2 = @time BAT.bat_sample_impl(rng, posterior2, MCMCSampling( nchains=4, nsteps=4*100000, store_burnin=true) )
@profview r_ram2 = @time BAT.bat_sample_impl(posterior2, TransformedMCMCSampling(pre_transform=PriorToGaussian(), nchains=4, nsteps=4*100000), context)

r_hmc2 = @time BAT.bat_sample_impl(rng, posterior2, MCMCSampling(mcalg=HamiltonianMC(), nchains=4, nsteps=4*20000) )
@profview r_mh2 = @time BAT.bat_sample_impl(posterior2, MCMCSampling( nchains=4, nsteps=4*100000, store_burnin=true), context)

r_hmc2 = @time BAT.bat_sample_impl(posterior2, MCMCSampling(mcalg=HamiltonianMC(), nchains=4, nsteps=4*20000), context)
31 changes: 16 additions & 15 deletions src/samplers/transformed_mcmc/chain_pool_init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,32 +37,33 @@ function _construct_chain(
id::Integer,
algorithm::TransformedMCMCSampling,
density::AbstractMeasureOrDensity,
initval_alg::InitvalAlgorithm
initval_alg::InitvalAlgorithm,
parent_context::BATContext
)
rng = AbstractRNG(rngpart, id)
v_init = bat_initval(rng, density, initval_alg).result

TransformedMCMCIterator(rng, algorithm, density, id, v_init)
new_context = set_rng(parent_context, AbstractRNG(rngpart, id))
v_init = bat_initval(density, initval_alg, new_context).result
return TransformedMCMCIterator(algorithm, density, id, v_init, new_context)

Check warning on line 45 in src/samplers/transformed_mcmc/chain_pool_init.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/chain_pool_init.jl#L43-L45

Added lines #L43 - L45 were not covered by tests
end

_gen_chains(
rngpart::RNGPartition,
ids::AbstractRange{<:Integer},
algorithm::TransformedMCMCSampling,
density::AbstractMeasureOrDensity,
initval_alg::InitvalAlgorithm
) = [_construct_chain(rngpart, id, algorithm, density, initval_alg) for id in ids]
initval_alg::InitvalAlgorithm,
context::BATContext
) = [_construct_chain(rngpart, id, algorithm, density, initval_alg, context) for id in ids]

Check warning on line 55 in src/samplers/transformed_mcmc/chain_pool_init.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/chain_pool_init.jl#L55

Added line #L55 was not covered by tests

#TODO
function mcmc_init!(

Check warning on line 58 in src/samplers/transformed_mcmc/chain_pool_init.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/chain_pool_init.jl#L58

Added line #L58 was not covered by tests
rng::AbstractRNG,
algorithm::TransformedMCMCSampling,
density::AbstractMeasureOrDensity,
nchains::Integer,
init_alg::TransformedMCMCChainPoolInit,
tuning_alg::TransformedMCMCTuningAlgorithm, # TODO: part of algorithm? # MCMCTuner
nonzero_weights::Bool,
callback::Function
callback::Function,
context::BATContext
)
@info "TransformedMCMCChainPoolInit: trying to generate $nchains viable MCMC chain(s)."

Check warning on line 68 in src/samplers/transformed_mcmc/chain_pool_init.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/chain_pool_init.jl#L68

Added line #L68 was not covered by tests

Expand All @@ -71,14 +72,15 @@ function mcmc_init!(
min_nviable::Int = minimum(init_alg.init_tries_per_chain) * nchains
max_ncandidates::Int = maximum(init_alg.init_tries_per_chain) * nchains

Check warning on line 73 in src/samplers/transformed_mcmc/chain_pool_init.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/chain_pool_init.jl#L72-L73

Added lines #L72 - L73 were not covered by tests

rngpart = RNGPartition(rng, Base.OneTo(max_ncandidates))
rngpart = RNGPartition(get_rng(context), Base.OneTo(max_ncandidates))

Check warning on line 75 in src/samplers/transformed_mcmc/chain_pool_init.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/chain_pool_init.jl#L75

Added line #L75 was not covered by tests

ncandidates::Int = 0

Check warning on line 77 in src/samplers/transformed_mcmc/chain_pool_init.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/chain_pool_init.jl#L77

Added line #L77 was not covered by tests

@debug "Generating dummy MCMC chain to determine chain, output and tuner types." #TODO: remove!

Check warning on line 79 in src/samplers/transformed_mcmc/chain_pool_init.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/chain_pool_init.jl#L79

Added line #L79 was not covered by tests

dummy_initval = unshaped(bat_initval(rng, density, InitFromTarget()).result, varshape(density))
dummy_chain = TransformedMCMCIterator(rng, algorithm, density, 1, dummy_initval)
dummy_context = deepcopy(context)
dummy_initval = unshaped(bat_initval(density, InitFromTarget(), dummy_context).result, varshape(density))
dummy_chain = TransformedMCMCIterator(algorithm, density, 1, dummy_initval, dummy_context)
dummy_tuner = get_tuner(tuning_alg, dummy_chain)
dummy_temperer = get_temperer(algorithm.tempering, density)

Check warning on line 85 in src/samplers/transformed_mcmc/chain_pool_init.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/chain_pool_init.jl#L81-L85

Added lines #L81 - L85 were not covered by tests

Expand All @@ -93,7 +95,7 @@ function mcmc_init!(
n = min(min_nviable, max_ncandidates - ncandidates)
@debug "Generating $n $(init_tries > 1 ? "additional " : "")candidate MCMC chain(s)."

Check warning on line 96 in src/samplers/transformed_mcmc/chain_pool_init.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/chain_pool_init.jl#L95-L96

Added lines #L95 - L96 were not covered by tests

new_chains = _gen_chains(rngpart, ncandidates .+ (one(Int64):n), algorithm, density, initval_alg)
new_chains = _gen_chains(rngpart, ncandidates .+ (one(Int64):n), algorithm, density, initval_alg, context)

Check warning on line 98 in src/samplers/transformed_mcmc/chain_pool_init.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/chain_pool_init.jl#L98

Added line #L98 was not covered by tests

filter!(isvalidchain, new_chains)

Check warning on line 100 in src/samplers/transformed_mcmc/chain_pool_init.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/chain_pool_init.jl#L100

Added line #L100 was not covered by tests

Expand Down Expand Up @@ -135,7 +137,6 @@ function mcmc_init!(
nsamples_thresh = floor(Int, 0.8 * median([nsamples(chain) for chain in viable_chains]))
good_idxs = findall(chain -> nsamples(chain) >= nsamples_thresh, viable_chains)
@debug "Found $(length(viable_chains)) MCMC chain(s) with at least $(nsamples_thresh) unique accepted samples."

Check warning on line 139 in src/samplers/transformed_mcmc/chain_pool_init.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/chain_pool_init.jl#L136-L139

Added lines #L136 - L139 were not covered by tests


append!(chains, view(viable_chains, good_idxs))
append!(tuners, view(viable_tuners, good_idxs))
Expand All @@ -153,7 +154,7 @@ function mcmc_init!(
tidxs = LinearIndices(chains)
n = length(tidxs)

Check warning on line 155 in src/samplers/transformed_mcmc/chain_pool_init.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/chain_pool_init.jl#L153-L155

Added lines #L153 - L155 were not covered by tests

modes = hcat(broadcast(samples -> Array(bat_findmode(rng, samples, MaxDensitySearch()).result), outputs)...)
modes = hcat(broadcast(samples -> Array(bat_findmode(samples, MaxDensitySearch(), context).result), outputs)...)

Check warning on line 157 in src/samplers/transformed_mcmc/chain_pool_init.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/chain_pool_init.jl#L157

Added line #L157 was not covered by tests

final_chains = similar(chains, 0)
final_tuners = similar(tuners, 0)
Expand Down
4 changes: 1 addition & 3 deletions src/samplers/transformed_mcmc/mcmc_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ end

# BAT.getmeasure(chain::SomeMCMCIter)::AbstractMeasureOrDensity

# BAT.getrng(chain::SomeMCMCIter)::AbstractRNG
# BAT.getcontext(chain::SomeMCMCIter)::BATContext

# BAT.mcmc_info(chain::SomeMCMCIter)::TransformedMCMCIteratorInfo

Expand Down Expand Up @@ -122,8 +122,6 @@ function getalgorithm end

function getmeasure end

function getrng end

function mcmc_info end

function nsteps end
Expand Down
47 changes: 27 additions & 20 deletions src/samplers/transformed_mcmc/mcmc_iterate.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
mutable struct TransformedMCMCIterator{
R<:AbstractRNG,
PR<:RNGPartition,
D<:BATMeasure,
F,
Q<:TransformedMCMCProposal,
SV<:DensitySampleVector,
S<:DensitySample,
CTX<:BATContext,
} <: MCMCIterator
rng::R
rngpart_cycle::PR
μ::D
f_transform::F
Expand All @@ -17,11 +16,12 @@ mutable struct TransformedMCMCIterator{
stepno::Int
n_accepted::Int
info::TransformedMCMCIteratorInfo
context::CTX
end

getmeasure(chain::TransformedMCMCIterator) = chain.μ

Check warning on line 22 in src/samplers/transformed_mcmc/mcmc_iterate.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_iterate.jl#L22

Added line #L22 was not covered by tests

getrng(chain::TransformedMCMCIterator) = chain.rng
get_context(chain::TransformedMCMCIterator) = chain.context

Check warning on line 24 in src/samplers/transformed_mcmc/mcmc_iterate.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_iterate.jl#L24

Added line #L24 was not covered by tests

mcmc_info(chain::TransformedMCMCIterator) = chain.info

Check warning on line 26 in src/samplers/transformed_mcmc/mcmc_iterate.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_iterate.jl#L26

Added line #L26 was not covered by tests

Expand All @@ -45,25 +45,25 @@ eff_acceptance_ratio(chain::TransformedMCMCIterator) = nsamples(chain) / chain.s

#ctor
function TransformedMCMCIterator(

Check warning on line 47 in src/samplers/transformed_mcmc/mcmc_iterate.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_iterate.jl#L47

Added line #L47 was not covered by tests
rng::AbstractRNG,
algorithm::TransformedMCMCSampling,
target,
id::Integer,
v_init::AbstractVector{<:Real}
v_init::AbstractVector{<:Real},
context::BATContext
)
TransformedMCMCIterator(rng, algorithm, target, Int32(id), v_init)
TransformedMCMCIterator(algorithm, target, Int32(id), v_init, context)

Check warning on line 54 in src/samplers/transformed_mcmc/mcmc_iterate.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_iterate.jl#L54

Added line #L54 was not covered by tests
end


#ctor
function TransformedMCMCIterator(

Check warning on line 59 in src/samplers/transformed_mcmc/mcmc_iterate.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_iterate.jl#L59

Added line #L59 was not covered by tests
rng::AbstractRNG,
algorithm::TransformedMCMCSampling,
target,
id::Int32,
v_init::AbstractVector{<:Real},
context::BATContext,
)
rngpart_cycle = RNGPartition(rng, 0:(typemax(Int16) - 2))
rngpart_cycle = RNGPartition(get_rng(context), 0:(typemax(Int16) - 2))

Check warning on line 66 in src/samplers/transformed_mcmc/mcmc_iterate.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_iterate.jl#L66

Added line #L66 was not covered by tests

μ = target
proposal = algorithm.proposal
Expand All @@ -72,7 +72,7 @@ function TransformedMCMCIterator(
n_accepted = 0

Check warning on line 72 in src/samplers/transformed_mcmc/mcmc_iterate.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_iterate.jl#L68-L72

Added lines #L68 - L72 were not covered by tests

adaptive_transform_spec = algorithm.adaptive_transform
g = init_adaptive_transform(rng, adaptive_transform_spec, μ)
g = init_adaptive_transform(adaptive_transform_spec, μ, context)

Check warning on line 75 in src/samplers/transformed_mcmc/mcmc_iterate.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_iterate.jl#L74-L75

Added lines #L74 - L75 were not covered by tests

logd_x = logdensityof(μ, v_init)
sample_x = DensitySample(v_init, logd_x, 1, TransformedMCMCTransformedSampleID(id, 1, 0), nothing) # TODO
Expand All @@ -84,7 +84,6 @@ function TransformedMCMCIterator(
samples = DensitySampleVector(([sample_x.v], [sample_x.logd], [sample_x.weight], [sample_x.info], [sample_x.aux] ))

Check warning on line 84 in src/samplers/transformed_mcmc/mcmc_iterate.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_iterate.jl#L84

Added line #L84 was not covered by tests

iter = TransformedMCMCIterator(

Check warning on line 86 in src/samplers/transformed_mcmc/mcmc_iterate.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_iterate.jl#L86

Added line #L86 was not covered by tests
rng,
rngpart_cycle,
target,
g,
Expand All @@ -93,7 +92,8 @@ function TransformedMCMCIterator(
sample_z,
stepno,
n_accepted,
TransformedMCMCIteratorInfo(id, cycle, false, false)
TransformedMCMCIteratorInfo(id, cycle, false, false),
context
)


Expand All @@ -109,9 +109,10 @@ end


function propose_mcmc(

Check warning on line 111 in src/samplers/transformed_mcmc/mcmc_iterate.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_iterate.jl#L111

Added line #L111 was not covered by tests
iter::TransformedMCMCIterator{<:Any, <:Any, <:Any, <:Any, <:TransformedMHProposal}
iter::TransformedMCMCIterator{<:Any, <:Any, <:Any, <:TransformedMHProposal}
)
@unpack rng, μ, f_transform, proposal, samples, sample_z, stepno = iter
@unpack μ, f_transform, proposal, samples, sample_z, stepno, context = iter
rng = get_rng(context)
sample_x = last(samples)
x, logd_x = sample_x.v, sample_x.logd
z, logd_z = sample_z.v, sample_z.logd

Check warning on line 118 in src/samplers/transformed_mcmc/mcmc_iterate.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_iterate.jl#L114-L118

Added lines #L114 - L118 were not covered by tests
Expand Down Expand Up @@ -157,7 +158,8 @@ function transformed_mcmc_step!!(
tuner::TransformedAbstractMCMCTunerInstance,
tempering::TransformedMCMCTemperingInstance,
)
@unpack rng, μ, f_transform, proposal, samples, sample_z, stepno = iter
@unpack μ, f_transform, proposal, samples, sample_z, stepno, context = iter
rng = get_rng(context)
sample_x = last(samples)
x, logd_x = sample_x.v, sample_x.logd
z, logd_z = sample_z.v, sample_z.logd
Expand All @@ -168,7 +170,7 @@ function transformed_mcmc_step!!(
z_proposed, logd_z_proposed = sample_z_proposed.v, sample_z_proposed.logd
x_proposed, logd_x_proposed = sample_x_proposed.v, sample_x_proposed.logd

Check warning on line 171 in src/samplers/transformed_mcmc/mcmc_iterate.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_iterate.jl#L170-L171

Added lines #L170 - L171 were not covered by tests

tuner_new, f_transform = tune_mcmc_transform!!(rng, tuner, f_transform, p_accept, z_proposed, z, stepno)
tuner_new, f_transform = tune_mcmc_transform!!(tuner, f_transform, p_accept, z_proposed, z, stepno, context)

Check warning on line 173 in src/samplers/transformed_mcmc/mcmc_iterate.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_iterate.jl#L173

Added line #L173 was not covered by tests

accepted = rand(rng) <= p_accept

Check warning on line 175 in src/samplers/transformed_mcmc/mcmc_iterate.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_iterate.jl#L175

Added line #L175 was not covered by tests

Expand All @@ -193,11 +195,11 @@ function transformed_mcmc_step!!(

f_new = f_transform

Check warning on line 196 in src/samplers/transformed_mcmc/mcmc_iterate.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_iterate.jl#L196

Added line #L196 was not covered by tests

# iter_new = TransformedMCMCIterator(rng, μ_new, f_new, proposal, samples_new, sample_z_new, stepno, n_accepted+Int(accepted))
iter.rng = rng
# iter_new = TransformedMCMCIterator(μ_new, f_new, proposal, samples_new, sample_z_new, stepno, n_accepted+Int(accepted), context)
iter.μ, iter.f_transform, iter.samples, iter.sample_z = μ_new, f_new, samples_new, sample_z_new
iter.n_accepted += Int(accepted)
iter.stepno += 1
@assert iter.context === context

Check warning on line 202 in src/samplers/transformed_mcmc/mcmc_iterate.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_iterate.jl#L199-L202

Added lines #L199 - L202 were not covered by tests

return (iter, tuner_new, tempering_new)

Check warning on line 204 in src/samplers/transformed_mcmc/mcmc_iterate.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_iterate.jl#L204

Added line #L204 was not covered by tests
end
Expand Down Expand Up @@ -288,6 +290,8 @@ function transformed_mcmc_iterate!(
end


#=
# Unused?
function reset_chain(
rng::AbstractRNG,
chain::TransformedMCMCIterator,
Expand All @@ -296,15 +300,18 @@ function reset_chain(
#TODO reset cycle count?
chain.rngpart_cycle = rngpart_cycle
chain.info = TransformedMCMCIteratorInfo(chain.info, cycle=0)
chain.context = set_rng(chain.context, rng)
# wants a next_cycle!
# reset_rng_counters!(chain)
end
=#


function reset_rng_counters!(chain::TransformedMCMCIterator)
set_rng!(chain.rng, chain.rngpart_cycle, chain.info.cycle)
rngpart_step = RNGPartition(chain.rng, 0:(typemax(Int32) - 2))
set_rng!(chain.rng, rngpart_step, chain.stepno)
rng = get_rng(get_context(chain))
set_rng!(rng, chain.rngpart_cycle, chain.info.cycle)
rngpart_step = RNGPartition(rng, 0:(typemax(Int32) - 2))
set_rng!(rng, rngpart_step, chain.stepno)
nothing

Check warning on line 315 in src/samplers/transformed_mcmc/mcmc_iterate.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_iterate.jl#L310-L315

Added lines #L310 - L315 were not covered by tests
end

Expand Down
8 changes: 4 additions & 4 deletions src/samplers/transformed_mcmc/mcmc_sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,22 @@ bat_default(::Type{TransformedMCMCDispatch}, ::Val{:burnin}, trafo::AbstractTran


function bat_sample_impl(

Check warning on line 53 in src/samplers/transformed_mcmc/mcmc_sample.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_sample.jl#L53

Added line #L53 was not covered by tests
rng::AbstractRNG,
target::AnyMeasureOrDensity,
algorithm::TransformedMCMCSampling
algorithm::TransformedMCMCSampling,
context::BATContext
)
density_notrafo = convert(AbstractMeasureOrDensity, target)
density, trafo = transform_and_unshape(algorithm.pre_transform, density_notrafo)

Check warning on line 59 in src/samplers/transformed_mcmc/mcmc_sample.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_sample.jl#L58-L59

Added lines #L58 - L59 were not covered by tests

init = mcmc_init!(

Check warning on line 61 in src/samplers/transformed_mcmc/mcmc_sample.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_sample.jl#L61

Added line #L61 was not covered by tests
rng,
algorithm,
density,
algorithm.nchains,
apply_trafo_to_init(trafo, algorithm.init),
algorithm.tuning_alg,
algorithm.nonzero_weights,
algorithm.store_burnin ? algorithm.callback : nop_func
algorithm.store_burnin ? algorithm.callback : nop_func,
context
)

@unpack chains, tuners, temperers = init

Check warning on line 72 in src/samplers/transformed_mcmc/mcmc_sample.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_sample.jl#L72

Added line #L72 was not covered by tests
Expand Down
4 changes: 2 additions & 2 deletions src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ end


function tune_mcmc_transform!!(

Check warning on line 34 in src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl#L34

Added line #L34 was not covered by tests
rng::AbstractRNG,
tuner::TransformedMCMCNoOpTuner,
transform,
p_accept::Real,
z_proposed::Vector{<:Float64}, #TODO: use DensitySamples instead
z_current::Vector{<:Float64},
stepno::Int
stepno::Int,
context::BATContext
)
return (tuner, transform)

Check warning on line 43 in src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl#L43

Added line #L43 was not covered by tests

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,13 @@ tuning_callback(::TransformedProposalCovTuner) = nop_func

# this function is called in each mcmc_iterate step during tuning
function tune_mcmc_transform!!(

Check warning on line 131 in src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl#L131

Added line #L131 was not covered by tests
rng::AbstractRNG,
tuner::TransformedProposalCovTuner,
transform::Any, #AffineMaps.AbstractAffineMap,#{<:typeof(*), <:LowerTriangular{<:Real}},
p_accept::Real,
z_proposed::Vector{<:Float64}, #TODO: use DensitySamples instead
z_current::Vector{<:Float64},
stepno::Int
stepno::Int,
context::BATContext
)

return (tuner, transform)

Check warning on line 141 in src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl#L141

Added line #L141 was not covered by tests
Expand Down
Loading

0 comments on commit 9cb2625

Please sign in to comment.