Error with truncated distribution as prior #450

Closed
gipert opened this issue Sep 7, 2024 · 5 comments

Comments

@gipert

gipert commented Sep 7, 2024

BAT crashes deterministically when attempting to use truncated distributions as priors. If I set, in docs/src/bat_tutorial.jl:

prior = distprod(
    a = [Weibull(1.1, 5000), Weibull(1.1, 5000)],
    mu = [-2.0..0.0, 1.0..3.0],
    sigma = truncated(Normal(0, 2), lower=0)
)

I get:

> julia docs/src/bat_tutorial.jl
[ Info: Setting new default BAT context BATContext{Float64}(Random123.Philox4x{UInt64, 10}(0x8bdbac22c8306763, 0x54639feaadacd6cf, 0x6765c3d759a52d0e, 0x4c6d85b755dcf9f0, 0x1e067ade302b087d, 0xf3403de12719c1b2, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0), HeterogeneousComputing.CPUnit(), BAT._NoADSelected())
[ Info: MCMCChainPoolInit: trying to generate 4 viable MCMC chain(s).
┌ Debug: Generating dummy MCMC chain to determine chain, output and tuner types.
└ @ BAT ~/.julia/packages/BAT/nxtXP/src/samplers/mcmc/chain_pool_init.jl:80
ERROR: LoadError: ArgumentError: Can't derive numeric type for type Nothing
Stacktrace:
  [1] realnumtype(::Type{Nothing})
    @ ValueShapes ~/.julia/packages/ValueShapes/rT1Zi/src/value_shape.jl:22
  [2] map
    @ ./tuple.jl:293 [inlined]
  [3] map
    @ ./tuple.jl:294 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/ValueShapes/rT1Zi/src/value_shape.jl:31 [inlined]
  [5] realnumtype(::Type{Tuple{Float64, Float64, Float64, Nothing}})
    @ ValueShapes ~/.julia/packages/ValueShapes/rT1Zi/src/value_shape.jl:31
  [6] _dist_params_numtype(d::Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing})
    @ BAT ~/.julia/packages/BAT/nxtXP/src/transforms/distribution_transform.jl:307
  [7] _eval_dist_trafo_func
    @ ~/.julia/packages/BAT/nxtXP/src/transforms/distribution_transform.jl:373 [inlined]
  [8] apply_dist_trafo(trg_d::Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}, ::BAT.StandardUvUniform{Float64}, src_v::Float64)
    @ BAT ~/.julia/packages/BAT/nxtXP/src/transforms/distribution_transform.jl:395
  [9] apply_dist_trafo(trg_d::Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}, src_d::BAT.StandardUvNormal{Float64}, src_v::Float64)
    @ BAT ~/.julia/packages/BAT/nxtXP/src/transforms/distribution_transform.jl:296
 [10] WithForwardDiff
    @ ~/.julia/packages/ForwardDiffPullbacks/s8kVo/src/with_forwarddiff.jl:22 [inlined]
 [11] _broadcast_getindex_evalf
    @ ./broadcast.jl:709 [inlined]
 [12] _broadcast_getindex
    @ ./broadcast.jl:682 [inlined]
 [13] getindex
    @ ./broadcast.jl:636 [inlined]
 [14] copy
    @ ./broadcast.jl:942 [inlined]
 [15] materialize
    @ ./broadcast.jl:903 [inlined]
 [16] _product_dist_trafo_impl(trg_ds::FillArrays.Fill{Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}, 1, Tuple{Base.OneTo{Int64}}}, src_ds::BAT.StandardUvNormal{Float64}, src_v::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true})
    @ BAT ~/.julia/packages/BAT/nxtXP/src/transforms/distribution_transform.jl:540
 [17] apply_dist_trafo
    @ ~/.julia/packages/BAT/nxtXP/src/transforms/distribution_transform.jl:565 [inlined]
 [18] _stdmv_to_flat_ntdistelem(td::Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}, src_d::BAT.StandardMvNormal{Float64}, src_v::Vector{Float64}, src_acc::ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}})
    @ BAT ~/.julia/packages/BAT/nxtXP/src/transforms/distribution_transform.jl:615
 [19] #119
    @ ~/.julia/packages/BAT/nxtXP/src/transforms/distribution_transform.jl:625 [inlined]
 [20] map
    @ ./tuple.jl:319 [inlined]
 [21] map
    @ ./tuple.jl:322 [inlined]
 [22] apply_dist_trafo(trg_d::ValueShapes.UnshapedNTD{ValueShapes.NamedTupleDist{(:a, :mu, :sigma), Tuple{Product{Continuous, Weibull{Float64}, Vector{Weibull{Float64}}}, Product{Continuous, Uniform{Float64}, Vector{Uniform{Float64}}}, Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}}, Tuple{ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ScalarShape{Real}}}, NamedTuple}}, src_d::BAT.StandardMvNormal{Float64}, src_v::Vector{Float64})
    @ BAT ~/.julia/packages/BAT/nxtXP/src/transforms/distribution_transform.jl:625
 [23] apply_dist_trafo(trg_d::ValueShapes.NamedTupleDist{(:a, :mu, :sigma), Tuple{Product{Continuous, Weibull{Float64}, Vector{Weibull{Float64}}}, Product{Continuous, Uniform{Float64}, Vector{Uniform{Float64}}}, Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}}, Tuple{ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ScalarShape{Real}}}, NamedTuple}, src_d::BAT.StandardMvNormal{Float64}, src_v::Vector{Float64})
    @ BAT ~/.julia/packages/BAT/nxtXP/src/transforms/distribution_transform.jl:630
 [24] DistributionTransform
    @ ~/.julia/packages/BAT/nxtXP/src/transforms/distribution_transform.jl:156 [inlined]
 [25] macro expansion
    @ ~/.julia/packages/FunctionChains/piKSk/src/function_chain.jl:0 [inlined]
 [26] FunctionChain
    @ ~/.julia/packages/FunctionChains/piKSk/src/function_chain.jl:161 [inlined]
 [27] logdensityof
    @ ~/.julia/packages/DensityInterface/MCyV6/src/interface.jl:256 [inlined]
 [28] logdensityof(density::PosteriorMeasure{DensityInterface.LogFuncDensity{FunctionChains.FunctionChain{Tuple{BAT.DistributionTransform{ValueShapes.NamedTupleDist{(:a, :mu, :sigma), Tuple{Product{Continuous, Weibull{Float64}, Vector{Weibull{Float64}}}, Product{Continuous, Uniform{Float64}, Vector{Uniform{Float64}}}, Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}}, Tuple{ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ScalarShape{Real}}}, NamedTuple}, BAT.StandardMvNormal{Float64}}, var"#3#4"{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Int64}, typeof(fit_function)}}}}, BAT.BATDistMeasure{BAT.StandardMvNormal{Float64}}}, v::Vector{Float64})
    @ BAT ~/.julia/packages/BAT/nxtXP/src/measures/posterior_measure.jl:59
 [29] BAT.MHIterator(algorithm::MetropolisHastings{TDist{Float64}, RepetitionWeighting{Int64}, AdaptiveMHTuning}, target::PosteriorMeasure{DensityInterface.LogFuncDensity{FunctionChains.FunctionChain{Tuple{BAT.DistributionTransform{ValueShapes.NamedTupleDist{(:a, :mu, :sigma), Tuple{Product{Continuous, Weibull{Float64}, Vector{Weibull{Float64}}}, Product{Continuous, Uniform{Float64}, Vector{Uniform{Float64}}}, Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}}, Tuple{ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ScalarShape{Real}}}, NamedTuple}, BAT.StandardMvNormal{Float64}}, var"#3#4"{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Int64}, typeof(fit_function)}}}}, BAT.BATDistMeasure{BAT.StandardMvNormal{Float64}}}, info::BAT.MCMCIteratorInfo, x_init::Vector{Float64}, context::BATContext{Float64, Random123.Philox4x{UInt64, 10}, HeterogeneousComputing.CPUnit, BAT._NoADSelected})
    @ BAT ~/.julia/packages/BAT/nxtXP/src/samplers/mcmc/mh/mh_sampler.jl:92
 [30] MCMCIterator
    @ ~/.julia/packages/BAT/nxtXP/src/samplers/mcmc/mh/mh_sampler.jl:135 [inlined]
 [31] mcmc_init!(algorithm::MetropolisHastings{TDist{Float64}, RepetitionWeighting{Int64}, AdaptiveMHTuning}, density::PosteriorMeasure{DensityInterface.LogFuncDensity{FunctionChains.FunctionChain{Tuple{BAT.DistributionTransform{ValueShapes.NamedTupleDist{(:a, :mu, :sigma), Tuple{Product{Continuous, Weibull{Float64}, Vector{Weibull{Float64}}}, Product{Continuous, Uniform{Float64}, Vector{Uniform{Float64}}}, Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}}, Tuple{ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ScalarShape{Real}}}, NamedTuple}, BAT.StandardMvNormal{Float64}}, var"#3#4"{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Int64}, typeof(fit_function)}}}}, BAT.BATDistMeasure{BAT.StandardMvNormal{Float64}}}, nchains::Int64, init_alg::MCMCChainPoolInit, tuning_alg::AdaptiveMHTuning, nonzero_weights::Bool, callback::Function, context::BATContext{Float64, Random123.Philox4x{UInt64, 10}, HeterogeneousComputing.CPUnit, BAT._NoADSelected})
    @ BAT ~/.julia/packages/BAT/nxtXP/src/samplers/mcmc/chain_pool_init.jl:85
 [32] bat_sample_impl(m::PosteriorMeasure{DensityInterface.LogFuncDensity{var"#3#4"{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Int64}, typeof(fit_function)}}, BAT.BATDistMeasure{ValueShapes.NamedTupleDist{(:a, :mu, :sigma), Tuple{Product{Continuous, Weibull{Float64}, Vector{Weibull{Float64}}}, Product{Continuous, Uniform{Float64}, Vector{Uniform{Float64}}}, Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}}, Tuple{ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ScalarShape{Real}}}, NamedTuple}}}, algorithm::MCMCSampling{MetropolisHastings{TDist{Float64}, RepetitionWeighting{Int64}, AdaptiveMHTuning}, PriorToGaussian, MCMCChainPoolInit, MCMCMultiCycleBurnin, BrooksGelmanConvergence, typeof(BAT.nop_func)}, context::BATContext{Float64, Random123.Philox4x{UInt64, 10}, HeterogeneousComputing.CPUnit, BAT._NoADSelected})
    @ BAT ~/.julia/packages/BAT/nxtXP/src/samplers/mcmc/mcmc_sample.jl:46
 [33] bat_sample(target::PosteriorMeasure{DensityInterface.LogFuncDensity{var"#3#4"{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Int64}, typeof(fit_function)}}, BAT.BATDistMeasure{ValueShapes.NamedTupleDist{(:a, :mu, :sigma), Tuple{Product{Continuous, Weibull{Float64}, Vector{Weibull{Float64}}}, Product{Continuous, Uniform{Float64}, Vector{Uniform{Float64}}}, Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}}, Tuple{ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ScalarShape{Real}}}, NamedTuple}}}, algorithm::MCMCSampling{MetropolisHastings{TDist{Float64}, RepetitionWeighting{Int64}, AdaptiveMHTuning}, PriorToGaussian, MCMCChainPoolInit, MCMCMultiCycleBurnin, BrooksGelmanConvergence, typeof(BAT.nop_func)}, context::BATContext{Float64, Random123.Philox4x{UInt64, 10}, HeterogeneousComputing.CPUnit, BAT._NoADSelected})
    @ BAT ~/.julia/packages/BAT/nxtXP/src/algotypes/sampling_algorithm.jl:56
 [34] bat_sample(target::PosteriorMeasure{DensityInterface.LogFuncDensity{var"#3#4"{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Int64}, typeof(fit_function)}}, BAT.BATDistMeasure{ValueShapes.NamedTupleDist{(:a, :mu, :sigma), Tuple{Product{Continuous, Weibull{Float64}, Vector{Weibull{Float64}}}, Product{Continuous, Uniform{Float64}, Vector{Uniform{Float64}}}, Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}}, Tuple{ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ScalarShape{Real}}}, NamedTuple}}}, algorithm::MCMCSampling{MetropolisHastings{TDist{Float64}, RepetitionWeighting{Int64}, AdaptiveMHTuning}, PriorToGaussian, MCMCChainPoolInit, MCMCMultiCycleBurnin, BrooksGelmanConvergence, typeof(BAT.nop_func)})
    @ BAT ~/.julia/packages/BAT/nxtXP/src/algotypes/sampling_algorithm.jl:67
 [35] top-level scope
    @ ~/sw/src/BAT.jl/docs/src/bat_tutorial.jl:84
in expression starting at /home/gipert/sw/src/BAT.jl/docs/src/bat_tutorial.jl:84

Any workaround?

@oschulz
Member

oschulz commented Sep 7, 2024

That's strange, we have used truncated priors a lot. I'll fix this.

@oschulz
Member

oschulz commented Sep 11, 2024

Ah, with d = truncated(Normal(0, 2), lower=0) we have d.upper isa Nothing, which ValueShapes.realnumtype currently can't handle.

For now, just use truncated(Normal(0, 2), 0, Inf); I'll fix this in ValueShapes.
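A minimal sketch of the difference between the two forms, assuming Distributions.jl 0.25 or later, where the keyword form of truncated is available (the exact version is an assumption; the issue does not state it). The keyword form stores the missing upper bound as nothing, which is the trailing Nothing type parameter of the Truncated type in the stack trace above. The three-argument form with an explicit Inf stores a Float64 bound instead, which realnumtype can handle.

```julia
using Distributions

# One-sided truncation via keyword: the missing upper bound is stored as `nothing`,
# so the type is Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}.
d_kw = truncated(Normal(0, 2); lower=0)

# Workaround from the comment above: give an explicit infinite upper bound,
# so both bounds are stored as Float64.
d_inf = truncated(Normal(0, 2), 0, Inf)

println(typeof(d_kw.upper))   # Nothing
println(typeof(d_inf.upper))  # Float64

# Both describe the same half-normal density on [0, Inf).
println(isapprox(pdf(d_kw, 1.0), pdf(d_inf, 1.0)))  # true
```

In the tutorial prior this amounts to writing sigma = truncated(Normal(0, 2), 0, Inf) when running against a ValueShapes version older than v0.11.3.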

@oschulz
Member

oschulz commented Sep 11, 2024

Will be fixed by oschulz/ValueShapes.jl#78

@oschulz
Member

oschulz commented Sep 11, 2024

Fixed in ValueShapes v0.11.3.

@oschulz oschulz closed this as completed Sep 11, 2024
@gipert
Author

gipert commented Sep 11, 2024

Thanks Oli!
