Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename transformed (Update for Measure terminology) #441

Open
wants to merge 53 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
c1cec60
use Tranformed prefix for all structs
Jun 30, 2023
9cbacf2
use Transformed prefix for all types
Jun 30, 2023
6cbe392
update deps
Jun 30, 2023
e32abca
add AdaptiveTransform
Jun 30, 2023
50c02ba
include transformed_mcmc
Jun 30, 2023
b0794f0
Revert "use Transformed prefix for all types"
Cornelius-G Jun 30, 2023
e1d7aba
chnage replace_type script
Cornelius-G Jun 30, 2023
66fa0ad
use Transformed prefix for all abstract types
Jun 30, 2023
6c39fee
new TransformedMCMCSampling and old MCMCSampling now both working
Cornelius-G Jun 30, 2023
f640fbc
move example
Cornelius-G Jul 3, 2023
d0ce010
use full matrix instead of lower cholesky in AdaptiveMHTuner
Cornelius-G Jul 3, 2023
972d967
use cholesky lower for AdaptiveMHTuning
Cornelius-G Jul 4, 2023
756827f
Merge commit 'c9d7fd98bf61d05fedc0be1643b41d7aabe43c98' into RenameTr…
oschulz Jul 6, 2023
9cb2625
Adapt transformed MCMC code to using BATContext
oschulz Jul 6, 2023
bb97b3a
FIx TransformedAdaptiveMHTuning and example
oschulz Jul 6, 2023
b2d5449
Merge branch 'main' into RenameTransformed
oschulz Jul 9, 2023
119c3fe
Fix Project.toml
oschulz Jul 9, 2023
7f45cee
Fix include order in transformed_mcmc
oschulz Jul 9, 2023
9457054
Fix transformed_check_convergence!
oschulz Jul 9, 2023
6552733
Fix transformed bat_sample_impl and mcmc_burnin!
oschulz Jul 9, 2023
8dc1ebf
Merge branch 'main' into trafo-merge
oschulz Jul 9, 2023
0a027e8
Adapt transformed example to API changes
oschulz Jul 9, 2023
85dc22a
RAMTuner properly persist stepno through multi_cycle_burnin
waldie11 Jul 11, 2023
9602973
Merge remote-tracking branch 'origin/main' into RenameTransformed
waldie11 Jul 11, 2023
f8f4b8d
Merge branch 'main' into RenameTransformed
waldie11 Jul 19, 2023
109a5c9
rewrite mcmc_init! for optimized overall runtime
waldie11 Jul 24, 2023
756d0e4
add infrastructure to ease continue of chains
waldie11 Jul 25, 2023
3647b6e
spaces in return of bat_sample_impl
waldie11 Jul 25, 2023
902257d
introduce _bat_sample_continue
waldie11 Jul 25, 2023
8cc015b
ProgressMeter for known infrastructure
waldie11 Jul 25, 2023
41b817d
ahmc evaluates params [NaN,...] in times
waldie11 Jul 25, 2023
0c4a703
clustered init suggestion
waldie11 Jul 27, 2023
d988a84
switch to median in cluster selection
waldie11 Jul 28, 2023
9909eed
forward best cluster
waldie11 Jul 28, 2023
fc93e18
_cluster_selection correct forward of chains&tuner
waldie11 Jul 28, 2023
3372df3
_cluster_selection proper fail criterion
waldie11 Jul 28, 2023
d6c01b3
viable_idxs corrected
waldie11 Aug 2, 2023
a27f521
Merge branch 'main' into RenameTransformed
Cornelius-G Sep 25, 2023
b6565c3
quick fixes
Cornelius-G Sep 25, 2023
921dc8b
Update to measure terminology, fix type inference, adjust compats
Micki-D Jun 3, 2024
427481f
Relax scale parameter in _cluster_selection() for chain pool init
Micki-D Jun 16, 2024
7fd0b91
Move proposaldist.jl to transformed version
Micki-D Jun 17, 2024
cb6e102
Move mcmc_stats.jl to transformed version
Micki-D Jun 20, 2024
17edf2f
Move mcmc_weighting.jl to transformed version
Micki-D Jun 20, 2024
eeed71f
Preliminary move of mcmc_sampleid.jl to transformed version
Micki-D Jun 20, 2024
2ea612b
Move mcmc_sampleid.jl to transformed version
Micki-D Jun 20, 2024
bfa008e
Move chain_pool_init.jl to transformed version
Micki-D Jun 20, 2024
5be3b3f
Merge mcmc_sample.jl and its transformed version
Micki-D Jun 20, 2024
1fa71cc
Merge mcmc_algorithm.jl and its transformed version
Micki-D Jun 20, 2024
924cfdf
Merge multi_cycle_burnin.jl and transformed version
Micki-D Jun 21, 2024
a1033f1
Move mcmc_noop_tuner.jl to transformed version
Micki-D Jun 21, 2024
189c285
Adjust mcmc.jl
Micki-D Jun 21, 2024
1e9ede9
Stash Changes
Micki-D Jul 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"julia.environmentPath": "C:\\Users\\Cornelius\\.julia\\environments\\v1.9"
}
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ ParallelProcessingTools = "8e8a01fc-6193-5ca1-a2f1-20776dae4199"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
PositiveFactorizations = "85a6dd25-e78a-55b7-8502-1745935b8125"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Expand Down Expand Up @@ -127,7 +128,7 @@ LaTeXStrings = "1"
LinearAlgebra = "1"
MacroTools = "0.5"
Markdown = "1"
MeasureBase = "0.12, 0.13, 0.14"
MeasureBase = "0.14"
Measurements = "2"
NamedArrays = "0.9, 0.10"
NestedSamplers = "0.8"
Expand All @@ -138,6 +139,7 @@ Parameters = "0.12, 0.13"
Plots = "1"
PositiveFactorizations = "0.2"
Printf = "1"
ProgressMeter = "1"
Random = "1"
Random123 = "1.2"
RecipesBase = "0.7, 0.8, 1.0"
Expand Down
67 changes: 67 additions & 0 deletions examples/dev-internal/transformed_example.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
using BAT
using BAT.MeasureBase
using AffineMaps
using ChangesOfVariables
using BAT.LinearAlgebra
using BAT.Distributions
using BAT.InverseFunctions
import BAT: TransformedMCMCIterator, TransformedAdaptiveMHTuning, TransformedRAMTuner, TransformedMHProposal, TransformedNoTransformedMCMCTempering, transformed_mcmc_step!!, MCMCSampleID
using Random123, PositiveFactorizations
using AutoDiffOperators
import AdvancedHMC

import BAT: mcmc_iterate!, transformed_mcmc_iterate!, TransformedMCMCSampling

#ENV["JULIA_DEBUG"] = "BAT"

context = BATContext(ad = ADModule(:ForwardDiff))

posterior = BAT.example_posterior()

my_result = @time BAT.bat_sample_impl(posterior, TransformedMCMCSampling(pre_transform=PriorToGaussian(), nchains=4, nsteps=4*100000), context)


density_notrafo = convert(BATMeasure, posterior)
density, trafo = BAT.transform_and_unshape(PriorToGaussian(), density_notrafo, context)

s = cholesky(Positive, BAT._approx_cov(density)).L
f = BAT.CustomTransform(Mul(s))

my_result = @time BAT.bat_sample_impl(posterior, TransformedMCMCSampling(pre_transform=PriorToGaussian(), tuning_alg=TransformedAdaptiveMHTuning(), nchains=4, nsteps=4*100000, adaptive_transform=f), context)

my_samples = my_result.result



using Plots
plot(my_samples)

r_mh = @time BAT.bat_sample_impl(posterior, MCMCSampling( nchains=4, nsteps=4*100000, store_burnin=true), context)

r_hmc = @time BAT.bat_sample_impl(posterior, MCMCSampling(mcalg=HamiltonianMC(), nchains=4, nsteps=4*20000), context)

plot(bat_sample(posterior).result)

using BAT.Distributions
using BAT.ValueShapes
prior2 = NamedTupleDist(ShapedAsNT,
b = [4.2, 3.3],
a = Exponential(1.0),
c = Normal(1.0,3.0),
d = product_distribution(Weibull.(ones(2),1)),
e = Beta(1.0, 1.0),
f = MvNormal([0.3,-2.9],Matrix([1.7 0.5;0.5 2.3]))
)

posterior.likelihood.density._log_f(rand(posterior.prior))

posterior.likelihood.density._log_f(rand(prior2))

posterior2 = PosteriorDensity(BAT.logfuncdensity(posterior.likelihood.density._log_f), prior2)


@profview r_ram2 = @time BAT.bat_sample_impl(posterior2, TransformedMCMCSampling(pre_transform=PriorToGaussian(), nchains=4, nsteps=4*100000), context)

@profview r_mh2 = @time BAT.bat_sample_impl(posterior2, MCMCSampling( nchains=4, nsteps=4*100000, store_burnin=true), context)

r_hmc2 = @time BAT.bat_sample_impl(posterior2, MCMCSampling(mcalg=HamiltonianMC(), nchains=4, nsteps=4*20000), context)
1 change: 1 addition & 0 deletions ext/BATAdvancedHMCExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ BAT.pkgext(::Val{:AdvancedHMC}) = BAT.PackageExtension{:AdvancedHMC}()
using Random
using DensityInterface
using HeterogeneousComputing, AutoDiffOperators
using BAT.ChangesOfVariables

using BAT: MeasureLike, BATMeasure

Expand Down
11 changes: 11 additions & 0 deletions ext/BATHDF5Ext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,17 @@ _to_flat_array(A::AbstractArray{<:AbstractArray{<:Real}}) = _to_flat_array(Array

const _AnyRealArrayOrArrays = Union{AbstractArray{<:Real},AbstractArray{<:AbstractArray{<:Real}}}


# TODO: MD Discuss, is to handle "nothing" entries in MCMCSampleIDVector objects
function _h5io_write(datastore::H5DataStore, path::AbstractString, data::Vector{Union{Nothing, Int64}})
if any(isnothing.(data))
data_tmp = fill(0, length(data))
else
data_tmp = convert(Vector{Int64}, data)
end
_h5io_write(datastore, path, data_tmp)
end

function _h5io_write(datastore::H5DataStore, path::AbstractString, data::_AnyRealArrayOrArrays)
@nospecialize datastore, path, data
group = _h5io__get_or_create_group(datastore, dirname(path))
Expand Down
76 changes: 76 additions & 0 deletions ext/ahmc_impl/ahmc_sampler_impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,55 @@ BAT.bat_default(::Type{MCMCSampling}, ::Val{:burnin}, mcalg::HamiltonianMC, traf
BAT.get_mcmc_tuning(algorithm::HamiltonianMC) = algorithm.tuning



"""
BAT.TransformedHMCProposal

*BAT-internal, not part of stable public API.*
"""
mutable struct TransformedHMCProposal{
HA<:AdvancedHMC.Hamiltonian,
TR<:AdvancedHMC.Transition,
KRNL<:AdvancedHMC.HMCKernel
}<: BAT.TransformedMCMCProposal
hamiltonian::HA
transition::TR
kernel::KRNL
end

function TransformedHMCProposal(algorithm::HamiltonianMC, target::BATMeasure, context::BATContext, v_init::AbstractVector)
adsel = get_adselector(context)
rng = get_rng(context)
f = checked_logdensityof(target)
metric = ahmc_metric(algorithm.metric, v_init)
fg = valgrad_func(f, adsel)

init_hamiltonian = AdvancedHMC.Hamiltonian(metric, f, fg)
hamiltonian, init_transition = AdvancedHMC.sample_init(rng, init_hamiltonian, v_init)
integrator = _ahmc_set_step_size(algorithm.integrator, hamiltonian, v_init)
termination = _ahmc_convert_termination(algorithm.termination, v_init)
kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, termination))

# Perform a dummy step to get type-stable transition value:
transition = AdvancedHMC.transition(deepcopy(rng), deepcopy(hamiltonian), deepcopy(kernel), init_transition.z)

TransformedHMCProposal(hamiltonian, transition, kernel)
end

BAT._get_proposal(alg::HamiltonianMC, target::BATMeasure, context::BATContext, v_init::AbstractVector) = TransformedHMCProposal(alg, target, context, v_init)
BAT._get_adaptive_transform(alg::HamiltonianMC) = BAT.default_adaptive_transform(alg)

function MCMCSampleID(iter::TransformedMCMCIterator{<:Any, <:Any, <:Any, <:TransformedHMCProposal})
stat = AdvancedHMC.stat(iter.proposal.transition)

# TODO MD: Handle proposal-dependent tstat (only NUTS has tree_depth):
AHMCSampleID(
iter.info.id, iter.info.cycle, iter.stepno, CURRENT_SAMPLE,
stat.hamiltonian_energy, stat.tree_depth,
stat.numerical_error, stat.step_size
)
end

# MCMCIterator subtype for HamiltonianMC
mutable struct AHMCIterator{
AL<:HamiltonianMC,
Expand Down Expand Up @@ -297,3 +346,30 @@ end


BAT.eff_acceptance_ratio(chain::AHMCIterator) = nsamples(chain) / nsteps(chain)


function BAT.propose_mcmc(
iter::TransformedMCMCIterator{<:Any, <:Any, <:Any, <:TransformedHMCProposal}
)
μ, f_transform, proposal, samples, sample_z, context = iter.μ, iter.f_transform, iter.proposal, iter.samples, iter.sample_z, iter.context
rng = get_rng(context)
sample_x = last(samples)
x, logd_x = sample_x.v, sample_x.logd
z, logd_z = sample_z.v, sample_z.logd

n = size(z, 1)

proposal.transition = AdvancedHMC.transition(rng, proposal.hamiltonian, proposal.kernel, proposal.transition.z)

z_proposed = proposal.transition.z.θ
x_proposed, ladj = ChangesOfVariables.with_logabsdet_jacobian(f_transform, z_proposed)
logd_x_proposed = BAT.checked_logdensityof(μ, x_proposed)
logd_z_proposed = logd_x_proposed + ladj

p_accept = clamp(exp(logd_z_proposed-logd_z), 0, 1)

sample_z_proposed = BAT._rebuild_density_sample(sample_z, z_proposed, logd_z_proposed)
sample_x_proposed = BAT._rebuild_density_sample(sample_x, x_proposed, logd_x_proposed)

return sample_x_proposed, sample_z_proposed, p_accept
end
66 changes: 57 additions & 9 deletions ext/ahmc_impl/ahmc_tuner_impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,61 @@ mutable struct AHMCTuner{A<:AdvancedHMC.AbstractAdaptor} <: AbstractMCMCTunerIns
adaptor::A
end

function (tuning::HMCTuningAlgorithm)(chain::MCMCIterator)

function BAT.get_tuner(tuning::HMCTuningAlgorithm, chain::TransformedMCMCIterator)
θ = first(chain.samples).v
adaptor = ahmc_adaptor(tuning, chain.proposal.hamiltonian.metric, chain.proposal.kernel.τ.integrator, θ)
AHMCTuner(tuning.target_acceptance, adaptor)
end


function (tuning::HMCTuningAlgorithm)(chain::TransformedMCMCIterator)
θ = first(chain.samples).v
adaptor = ahmc_adaptor(tuning, chain.hamiltonian.metric, chain.kernel.τ.integrator, θ)
adaptor = ahmc_adaptor(tuning, chain.proposal.hamiltonian.metric, chain.proposal.kernel.τ.integrator, θ)
AHMCTuner(tuning.target_acceptance, adaptor)
end

# function (tuning::HMCTuningAlgorithm)(chain::MCMCIterator)
# θ = first(chain.samples).v
# adaptor = ahmc_adaptor(tuning, chain.hamiltonian.metric, chain.kernel.τ.integrator, θ)
# AHMCTuner(tuning.target_acceptance, adaptor)
# end

function BAT.tuning_init!(tuner::AHMCTuner, chain::MCMCIterator, max_nsteps::Integer)

function BAT.tuning_init!(tuner::AHMCTuner, chain::TransformedMCMCIterator, max_nsteps::Integer)
AdvancedHMC.Adaptation.initialize!(tuner.adaptor, Int(max_nsteps - 1))
nothing
end

BAT.tuning_postinit!(tuner::AHMCTuner, chain::MCMCIterator, samples::DensitySampleVector) = nothing
# function BAT.tuning_init!(tuner::AHMCTuner, chain::MCMCIterator, max_nsteps::Integer)
# AdvancedHMC.Adaptation.initialize!(tuner.adaptor, Int(max_nsteps - 1))
# nothing
# end



BAT.tuning_postinit!(tuner::AHMCTuner, chain::TransformedMCMCIterator, samples::DensitySampleVector) = nothing

# BAT.tuning_postinit!(tuner::AHMCTuner, chain::MCMCIterator, samples::DensitySampleVector) = nothing

function BAT.tuning_reinit!(tuner::AHMCTuner, chain::MCMCIterator, max_nsteps::Integer)


function BAT.tuning_reinit!(tuner::AHMCTuner, chain::TransformedMCMCIterator, max_nsteps::Integer)
AdvancedHMC.Adaptation.initialize!(tuner.adaptor, Int(max_nsteps - 1))
nothing
end

function BAT.tuning_update!(tuner::AHMCTuner, chain::MCMCIterator, samples::DensitySampleVector)

# function BAT.tuning_reinit!(tuner::AHMCTuner, chain::MCMCIterator, max_nsteps::Integer)
# AdvancedHMC.Adaptation.initialize!(tuner.adaptor, Int(max_nsteps - 1))
# nothing
# end

BAT.default_adaptive_transform(algorithm::HamiltonianMC) = BAT.TriangularAffineTransform()
BAT.default_adaptive_transform(tuning::HMCTuningAlgorithm) = BAT.TriangularAffineTransform()


function BAT.tuning_update!(tuner::AHMCTuner, chain::TransformedMCMCIterator, samples::DensitySampleVector)
max_log_posterior = maximum(samples.logd)
accept_ratio = eff_acceptance_ratio(chain)
if accept_ratio >= 0.9 * tuner.target_acceptance
Expand All @@ -38,11 +73,11 @@ function BAT.tuning_update!(tuner::AHMCTuner, chain::MCMCIterator, samples::Dens
nothing
end

function BAT.tuning_finalize!(tuner::AHMCTuner, chain::MCMCIterator)
function BAT.tuning_finalize!(tuner::AHMCTuner, chain::TransformedMCMCIterator)
adaptor = tuner.adaptor
AdvancedHMC.finalize!(adaptor)
chain.hamiltonian = AdvancedHMC.update(chain.hamiltonian, adaptor)
chain.kernel = AdvancedHMC.update(chain.kernel, adaptor)
chain.proposal.hamiltonian = AdvancedHMC.update(chain.proposal.hamiltonian, adaptor)
chain.proposal.kernel = AdvancedHMC.update(chain.proposal.kernel, adaptor)
nothing
end

Expand All @@ -66,3 +101,16 @@ function (callback::AHMCTunerCallback)(::Val{:mcmc_step}, chain::AHMCIterator)

nothing
end

function BAT.tune_mcmc_transform!!(
tuner::AHMCTuner,
transform::Any, #AffineMaps.AbstractAffineMap,#{<:typeof(*), <:LowerTriangular{<:Real}},
p_accept::Real,
z_proposed::Vector{<:Float64}, #TODO: use DensitySamples instead
z_current::Vector{<:Float64},
stepno::Int,
context::BATContext
)

return (tuner, transform, false)
end
1 change: 1 addition & 0 deletions src/BAT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ import EmpiricalDistributions
import HypothesisTests
import Measurements
import NamedArrays
import ProgressMeter
import Random123
import Sobol
import StableRNGs
Expand Down
16 changes: 8 additions & 8 deletions src/measures/bat_pushfwd_measure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ MeasureBase.pullback(f, m::BATMeasure) = _bat_pulbck(f, m, KeepRootMeasure())
MeasureBase.pullback(f, m::BATMeasure, volcorr::KeepRootMeasure) = _bat_pulbck(f, m, volcorr)
MeasureBase.pullback(f, m::BATMeasure, volcorr::ChangeRootMeasure) = _bat_pulbck(f, m, volcorr)

_bat_pulbck(f, m::BATMeasure, volcorr::PushFwdStyle) = pushfwd(inverse(f), m, volcorr)
_bat_pulbck(f, m::BATMeasure, volcorr::PushFwdStyle) = MeasureBase.pushfwd(inverse(f), m, volcorr)


# ToDo: remove
Expand All @@ -84,18 +84,18 @@ function DensityInterface.logdensityof(@nospecialize(m::_NonBijectiveBATPusfwdMe
end

function DensityInterface.logdensityof(m::BATPushFwdMeasure{F,I,M,ChangeRootMeasure}, v::Any) where {F,I,M}
v_orig = inverse(m.trafo)(v)
logdensityof(parent(m), v_orig)
v_orig = m.finv(v)
logdensityof(m.origin, v_orig)
end

function checked_logdensityof(m::BATPushFwdMeasure{F,I,M,ChangeRootMeasure}, v::Any) where {F,I,M}
v_orig = inverse(m.trafo)(v)
checked_logdensityof(parent(m), v_orig)
v_orig = m.finv(v)
checked_logdensityof(m.origin, v_orig)
end


function _v_orig_and_ladj(m::BATPushFwdMeasure, v::Any)
with_logabsdet_jacobian(inverse(m.trafo), v)
with_logabsdet_jacobian(m.finv, v)
end

# TODO: Would profit from custom pullback:
Expand Down Expand Up @@ -123,13 +123,13 @@ end

function DensityInterface.logdensityof(m::BATPushFwdMeasure{F,I,M,KeepRootMeasure}, v::Any) where {F,I,M}
v_orig, ladj = _v_orig_and_ladj(m, v)
logd_orig = logdensityof(parent(m), v_orig)
logd_orig = logdensityof(m.origin, v_orig)
_combine_logd_with_ladj(logd_orig, ladj)
end

function checked_logdensityof(m::BATPushFwdMeasure{F,I,M,KeepRootMeasure}, v::Any) where {F,I,M}
v_orig, ladj = _v_orig_and_ladj(m, v)
logd_orig = logdensityof(parent(m), v_orig)
logd_orig = logdensityof(m.origin, v_orig)
isnan(logd_orig) && @throw_logged EvalException(logdensityof, m, v, 0)
_combine_logd_with_ladj(logd_orig, ladj)
end
Expand Down
Loading
Loading