Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ReactiveMP.sampletype #229

Merged
merged 18 commits into from
Jan 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions src/distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,48 @@ Converts (if possible) the elements of the `container` to be of type `E`.
convert_eltype(::Type{E}, container::AbstractArray) where {E} = convert(AbstractArray{E}, container)
convert_eltype(::Type{E}, number::Number) where {E} = convert(E, number)

"""
sampletype(distribution)

Returns a type of the distribution. By default fallbacks to the `eltype`.

See also: [`ReactiveMP.samplefloattype`](@ref), [`ReactiveMP.promote_sampletype`](@ref), [`ReactiveMP.promotesamplefloatype`](@ref)
"""
sampletype(distribution) = eltype(distribution)

sampletype(distribution::Distribution) = sampletype(variate_form(distribution), distribution)
sampletype(::Type{Univariate}, distribution) = eltype(distribution)
sampletype(::Type{Multivariate}, distribution) = Vector{eltype(distribution)}
sampletype(::Type{Matrixvariate}, distribution) = Matrix{eltype(distribution)}

"""
samplefloattype(distribution)

Returns a type of the distribution or the underlying float type in case if sample is `Multivariate` or `Matrixvariate`.
By default fallbacks to the `deep_eltype(sampletype(distribution))`.

See also: [`ReactiveMP.sampletype`](@ref), [`ReactiveMP.promote_sampletype`](@ref), [`ReactiveMP.promote_samplefloatype`](@ref)
"""
samplefloattype(distribution) = deep_eltype(sampletype(distribution))

"""
promote_sampletype(distributions...)

Promotes `sampletype` of the `distributions` to a single type. See also `promote_type`.

See also: [`ReactiveMP.sampletype`](@ref), [`ReactiveMP.samplefloattype`](@ref), [`ReactiveMP.promote_samplefloattype`](@ref)
"""
promote_sampletype(distributions...) = promote_type(sampletype.(distributions)...)

"""
promote_samplefloattype(distributions...)

Promotes `samplefloattype` of the `distributions` to a single type. See also `promote_type`.

See also: [`ReactiveMP.sampletype`](@ref), [`ReactiveMP.samplefloattype`](@ref), [`ReactiveMP.promote_sampletype`](@ref)
"""
promote_samplefloattype(distributions...) = promote_type(samplefloattype.(distributions)...)

"""
logpdf_sample_friendly(distribution)

Expand Down
2 changes: 1 addition & 1 deletion src/distributions/beta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ prod_analytical_rule(::Type{<:Beta}, ::Type{<:Beta}) = ProdAnalyticalRuleAvailab
function prod(::ProdAnalytical, left::Beta, right::Beta)
left_a, left_b = params(left)
right_a, right_b = params(right)
T = promote_type(eltype(left), eltype(right))
T = promote_samplefloattype(left, right)
return Beta(left_a + right_a - one(T), left_b + right_b - one(T))
end

Expand Down
6 changes: 3 additions & 3 deletions src/distributions/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ vague(::Type{<:GammaShapeScale}) = GammaShapeScale(1.0, huge)
prod_analytical_rule(::Type{<:GammaShapeScale}, ::Type{<:GammaShapeScale}) = ProdAnalyticalRuleAvailable()

function prod(::ProdAnalytical, left::GammaShapeScale, right::GammaShapeScale)
T = promote_type(eltype(left), eltype(right))
T = promote_samplefloattype(left, right)
return GammaShapeScale(shape(left) + shape(right) - one(T), (scale(left) * scale(right)) / (scale(left) + scale(right)))
end

Expand Down Expand Up @@ -59,12 +59,12 @@ prod_analytical_rule(::Type{<:GammaShapeRate}, ::Type{<:GammaShapeScale}) = Prod
prod_analytical_rule(::Type{<:GammaShapeScale}, ::Type{<:GammaShapeRate}) = ProdAnalyticalRuleAvailable()

function prod(::ProdAnalytical, left::GammaShapeRate, right::GammaShapeScale)
T = promote_type(eltype(left), eltype(right))
T = promote_samplefloattype(left, right)
return GammaShapeRate(shape(left) + shape(right) - one(T), rate(left) + rate(right))
end

function prod(::ProdAnalytical, left::GammaShapeScale, right::GammaShapeRate)
T = promote_type(eltype(left), eltype(right))
T = promote_samplefloattype(left, right)
return GammaShapeScale(shape(left) + shape(right) - one(T), (scale(left) * scale(right)) / (scale(left) + scale(right)))
end

Expand Down
15 changes: 14 additions & 1 deletion src/distributions/gamma_shape_rate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ export GammaShapeRate
import Distributions: Gamma, shape, rate, logpdf
import SpecialFunctions: loggamma, digamma, gamma
import StatsFuns: log2π
import Random: rand

import Base

Expand Down Expand Up @@ -58,9 +59,21 @@ vague(::Type{<:GammaShapeRate}) = GammaShapeRate(1.0, tiny)
prod_analytical_rule(::Type{<:GammaShapeRate}, ::Type{<:GammaShapeRate}) = ProdAnalyticalRuleAvailable()

function prod(::ProdAnalytical, left::GammaShapeRate, right::GammaShapeRate)
T = promote_type(eltype(left), eltype(right))
T = promote_samplefloattype(left, right)
return GammaShapeRate(shape(left) + shape(right) - one(T), rate(left) + rate(right))
end

Distributions.pdf(dist::GammaShapeRate, x::Real) = (rate(dist)^shape(dist)) / gamma(shape(dist)) * x^(shape(dist) - 1) * exp(-rate(dist) * x)
Distributions.logpdf(dist::GammaShapeRate, x::Real) = shape(dist) * log(rate(dist)) - loggamma(shape(dist)) + (shape(dist) - 1) * log(x) - rate(dist) * x

function Random.rand(rng::AbstractRNG, dist::GammaShapeRate)
return convert(eltype(dist), rand(rng, convert(GammaShapeScale, dist)))
end

function Random.rand(rng::AbstractRNG, dist::GammaShapeRate, n::Integer)
return convert(AbstractArray{eltype(dist)}, rand(rng, convert(GammaShapeScale, dist), n))
end

function Random.rand!(rng::AbstractRNG, dist::GammaShapeRate, container::AbstractVector)
return rand!(rng, convert(GammaShapeScale, dist), container)
end
2 changes: 1 addition & 1 deletion src/distributions/matrix_dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@ mean(::typeof(log), dist::MatrixDirichlet) = digamma.(dist.a) .- digamma.(sum(di
prod_analytical_rule(::Type{<:MatrixDirichlet}, ::Type{<:MatrixDirichlet}) = ProdAnalyticalRuleAvailable()

function prod(::ProdAnalytical, left::MatrixDirichlet, right::MatrixDirichlet)
T = promote_type(eltype(left), eltype(right))
T = promote_samplefloattype(left, right)
return MatrixDirichlet(left.a + right.a .- one(T))
end
4 changes: 2 additions & 2 deletions src/distributions/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ end

function Random.rand(rng::AbstractRNG, dist::UnivariateNormalDistributionsFamily{T}) where {T}
μ, σ = mean_std(dist)
return μ + σ * randn(rng, float(T))
return μ + σ * randn(rng, T)
end

function Random.rand(rng::AbstractRNG, dist::UnivariateNormalDistributionsFamily{T}, size::Int64) where {T}
Expand All @@ -359,7 +359,7 @@ end

function Random.rand(rng::AbstractRNG, dist::MultivariateNormalDistributionsFamily{T}) where {T}
μ, L = mean_std(dist)
return μ + L * randn(rng, length(μ))
return μ + L * randn(rng, T, length(μ))
end

function Random.rand(rng::AbstractRNG, dist::MultivariateNormalDistributionsFamily{T}, size::Int64) where {T}
Expand Down
2 changes: 2 additions & 0 deletions src/distributions/pointmass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ variate_form(::PointMass{M}) where {T, M <: AbstractMatrix{T}} = Matrixvariate

##

sampletype(distribution::PointMass{T}) where {T} = T

getpointmass(distribution::PointMass) = distribution.point

##
Expand Down
16 changes: 16 additions & 0 deletions src/distributions/sample_list.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ export SampleList, SampleListMeta
import Base: show, ndims, length, size, precision, getindex, broadcasted, map
import Distributions: mean, var, cov, std
import StatsBase: Weights
import Random: rand

using StaticArrays
using LoopVectorization
Expand Down Expand Up @@ -107,6 +108,8 @@ const DEFAULT_SAMPLE_LIST_N_SAMPLES = 5000

Base.eltype(::Type{<:SampleList{D, S, W}}) where {D, S, W} = Tuple{sample_list_eltype(SampleList, D, S), eltype(W)}

sampletype(::SampleList{D, S}) where {D, S} = sample_list_eltype(SampleList, D, S)

sample_list_eltype(::Type{SampleList}, ndims::Tuple{}, ::Type{S}) where {S} = eltype(S)
sample_list_eltype(::Type{SampleList}, ndims::Tuple{Int}, ::Type{S}) where {S} = SVector{ndims[1], eltype(S)}
sample_list_eltype(::Type{SampleList}, ndims::Tuple{Int, Int}, ::Type{S}) where {S} = SMatrix{ndims[1], ndims[2], eltype(S), ndims[1] * ndims[2]}
Expand Down Expand Up @@ -218,6 +221,19 @@ vague(::Type{SampleList}, dim1::Int, dim2::Int; nsamples::Int = DEFAULT_SAMPLE_L

##

rand(samplelist::SampleList) = rand(Random.GLOBAL_RNG, samplelist)
rand(samplelist::SampleList, n::Integer) = rand(Random.GLOBAL_RNG, samplelist, n)

function rand(rng::AbstractRNG, samplelist::SampleList)
return rand(rng, get_samples(samplelist))
end

function rand(rng::AbstractRNG, samplelist::SampleList, n::Integer)
return rand(rng, get_samples(samplelist), n)
end

##

sample_list_default_prod_strategy() = BootstrapImportanceSampling()

## prod related stuff
Expand Down
4 changes: 2 additions & 2 deletions src/nodes/mv_normal_mean_covariance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
m_out, v_out = mean_cov(q_out)
inv_m_Σ = mean(cholinv, q_Σ)

result = zero(promote_type(eltype(m_mean), eltype(m_out), eltype(inv_m_Σ)))
result = zero(promote_samplefloattype(q_out, q_μ, q_Σ))
result += mean(logdet, q_Σ)
result += dim * log2π
@inbounds for k1 in 1:dim, k2 in 1:dim # optimize trace operation (indices can be interchanges because of symmetry)
Expand All @@ -29,7 +29,7 @@ end
m, V = mean_cov(q_out_μ)
inv_m_Σ = mean(cholinv, q_Σ)

result = zero(promote_type(eltype(m), eltype(inv_m_Σ)))
result = zero(promote_samplefloattype(q_out_μ, q_Σ))
result += mean(logdet, q_Σ)
result += dim * log2π
@inbounds for k1 in 1:dim, k2 in 1:dim # optimize trace operation (indices can be interchanges because of symmetry)
Expand Down
11 changes: 6 additions & 5 deletions src/nodes/mv_normal_mean_precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import StatsFuns: log2π
m_out, v_out = mean_cov(q_out)
m_Λ = mean(q_Λ)

result = zero(promote_type(eltype(m_mean), eltype(m_out), eltype(m_Λ)))
result = zero(promote_samplefloattype(q_out, q_μ, q_Λ))
result += dim * log2π
result -= mean(logdet, q_Λ)
@inbounds for k1 in 1:dim, k2 in 1:dim
Expand All @@ -34,7 +34,7 @@ end
m_out, v_out = mean_cov(q_out)
df_Λ, S_Λ = params(q_Λ) # prevent allocation of mean matrix

T = promote_type(eltype(m_mean), eltype(m_out), eltype(S_Λ))
T = promote_type(samplefloattype(q_out), samplefloattype(q_μ), typeof(df_Λ), eltype(S_Λ))
result = zero(T)

@inbounds for k1 in 1:dim, k2 in 1:dim
Expand All @@ -58,7 +58,7 @@ end
m, V = mean_cov(q_out_μ)
m_Λ = mean(q_Λ)

T = promote_type(eltype(m), eltype(m_Λ))
T = promote_samplefloattype(q_out_μ, q_Λ)

result = zero(T)
result += dim * convert(T, log2π)
Expand All @@ -80,14 +80,15 @@ end
m, V = mean_cov(q_out_μ)
df_Λ, S_Λ = params(q_Λ) # prevent allocation of mean matrix

result = zero(promote_type(eltype(m), eltype(S_Λ)))
T = promote_type(samplefloattype(q_out_μ), typeof(df_Λ), eltype(S_Λ))
result = zero(T)

@inbounds for k1 in 1:dim, k2 in 1:dim
# optimize trace operation (indices can be interchanges because of symmetry)
result += S_Λ[k1, k2] * (V[k1, k2] + V[dim + k1, dim + k2] - V[dim + k1, k2] - V[k1, dim + k2] + (m[k1] - m[dim + k1]) * (m[k2] - m[dim + k2]))
end
result *= df_Λ
result += dim * log2π
result += dim * convert(T, log2π)
result -= mean(logdet, q_Λ)
result /= 2

Expand Down
4 changes: 2 additions & 2 deletions src/nodes/mv_normal_mean_scale_precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ end
m_out, v_out = mean_cov(q_out)
m_Λ = mean(q_γ) * diageye(dim)

result = zero(promote_type(eltype(m_mean), eltype(m_out), eltype(m_Λ)))
result = zero(promote_samplefloattype(q_out, q_μ, q_γ))
result += dim * log2π
result -= dim * mean(log, q_γ)
@inbounds for k1 in 1:dim, k2 in 1:dim
Expand All @@ -37,7 +37,7 @@ end
m, V = mean_cov(q_out_μ)
m_Λ = mean(q_γ) * diageye(dim)

result = zero(promote_type(eltype(m), eltype(m_Λ)))
result = zero(promote_samplefloattype(q_out_μ, q_γ))
result += dim * log2π
result -= dim * mean(log, q_γ)
@inbounds for k1 in 1:dim, k2 in 1:dim
Expand Down
2 changes: 1 addition & 1 deletion src/rules/addition/in1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ end
ξout, wout = weightedmean_precision(m_out)
ξin1 = wout * mean(m_in2)
ξin1 .-= ξout
T = promote_type(T1, eltype(m_in2))
T = promote_type(T1, samplefloattype(m_in2))
ξin1 .*= -one(T)
return MvNormalWeightedMeanPrecision(ξin1, wout)
end
Expand Down
2 changes: 1 addition & 1 deletion src/rules/addition/marginals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ end
xi_in1, W_in1 = weightedmean_precision(m_in1)
xi_in2, W_in2 = weightedmean_precision(m_in2)

T = promote_type(eltype(W_out), eltype(W_in1), eltype(W_in2))
T = promote_samplefloattype(m_out, m_in1, m_in2)
d = length(xi_out)
Λ = Matrix{T}(undef, (2 * d, 2 * d))
@inbounds for k2 in 1:d
Expand Down
2 changes: 1 addition & 1 deletion src/rules/bifm/in.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

# Actual return type depends on meta object as well, so we explicitly cast the result here
# Should be noop if type matches
T = promote_type(eltype(m_out), eltype(m_zprev), eltype(m_znext))
T = promote_samplefloattype(m_out, m_zprev, m_znext)

# return input marginal
return ProdFinal(convert(MvNormalMeanCovariance{T}, MvNormalMeanCovariance(μ_in, Σ_in)))
Expand Down
2 changes: 1 addition & 1 deletion src/rules/bifm/marginals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

# Actual return type depends on meta object as well, so we explicitly cast the result here
# Should be noop if type matches
T = promote_type(eltype(m_out), eltype(m_in), eltype(m_zprev), eltype(m_znext))
T = promote_samplefloattype(m_out, m_in, m_zprev, m_znext)

# return joint marginal
left = convert(MvNormalWeightedMeanPrecision{T}, MvNormalWeightedMeanPrecision(ξ3, Λ3))
Expand Down
2 changes: 1 addition & 1 deletion src/rules/bifm/out.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

# Actual return type depends on meta object as well, so we explicitly cast the result here
# Should be noop if type matches
T = promote_type(eltype(m_in), eltype(m_zprev), eltype(m_znext))
T = promote_samplefloattype(m_in, m_zprev, m_znext)

# return outgoing marginal
return ProdFinal(convert(MvNormalMeanCovariance{T}, MvNormalMeanCovariance(μ_out, Σ_out)))
Expand Down
2 changes: 1 addition & 1 deletion src/rules/bifm/znext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

# Actual return type depends on meta object as well, so we explicitly cast the result here
# Should be noop if type matches
T = promote_type(eltype(m_out), eltype(m_in), eltype(m_zprev))
T = promote_samplefloattype(m_out, m_in, m_zprev)

# return outgoing marginal
return ProdFinal(convert(MvNormalMeanCovariance{T}, MvNormalMeanCovariance(μ_znext, Σ_znext)))
Expand Down
2 changes: 1 addition & 1 deletion src/rules/bifm/zprev.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

# Actual return type depends on meta object as well, so we explicitly cast the result here
# Should be noop if type matches
T = promote_type(eltype(m_out), eltype(m_in), eltype(m_znext))
T = promote_samplefloattype(m_out, m_in, m_znext)

# return message
return convert(MvNormalWeightedMeanPrecision{T}, MvNormalWeightedMeanPrecision(ξ_zprev, Λ_zprev))
Expand Down
4 changes: 2 additions & 2 deletions src/rules/mv_normal_mean_covariance/marginals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ end

xi = [xi_out; xi_m]

T = promote_type(eltype(W_bar), eltype(W_out), eltype(W_m))
T = promote_samplefloattype(m_out, m_μ, m_Σ)
d = length(xi_out)
W = Matrix{T}(undef, (2 * d, 2 * d))
@inbounds for k2 in 1:d
Expand Down Expand Up @@ -47,7 +47,7 @@ end

xi = [xi_out; xi_m]

T = promote_type(eltype(W_bar), eltype(W_out), eltype(W_m))
T = promote_samplefloattype(m_out, m_μ, q_Σ)
d = length(xi_out)
W = Matrix{T}(undef, (2 * d, 2 * d))
@inbounds for k2 in 1:d
Expand Down
4 changes: 2 additions & 2 deletions src/rules/mv_normal_mean_precision/marginals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ end

W_bar = mean(m_Λ)

T = promote_type(eltype(W_bar), eltype(W_y), eltype(W_m))
T = promote_samplefloattype(m_out, m_μ, m_Λ)
d = length(xi_y)
Λ = Matrix{T}(undef, (2 * d, 2 * d))
@inbounds for k2 in 1:d
Expand Down Expand Up @@ -45,7 +45,7 @@ end

W_bar = mean(q_Λ)

T = promote_type(eltype(W_bar), eltype(W_y), eltype(W_m))
T = promote_samplefloattype(m_out, m_μ, q_Λ)
d = length(xi_y)
Λ = Matrix{T}(undef, (2 * d, 2 * d))
@inbounds for k2 in 1:d
Expand Down
6 changes: 3 additions & 3 deletions src/rules/mv_normal_mean_scale_precision/marginals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ end
xi_y, W_y = weightedmean_precision(m_out)
xi_m, W_m = weightedmean_precision(m_μ)

W_bar = mean(m_γ) * diageye(eltype(m_out), ndims(m_out))
W_bar = mean(m_γ) * diageye(samplefloattype(m_out), ndims(m_out))

T = promote_type(eltype(W_bar), eltype(W_y), eltype(W_m))
T = promote_samplefloattype(m_out, m_μ, m_γ)
d = length(xi_y)
Λ = Matrix{T}(undef, (2 * d, 2 * d))
@inbounds for k2 in 1:d
Expand All @@ -41,7 +41,7 @@ end

W_bar = mean(q_γ) * diageye(eltype(m_out), ndims(m_out))

T = promote_type(eltype(W_bar), eltype(W_y), eltype(W_m))
T = promote_samplefloattype(m_out, m_μ, q_γ)
d = length(xi_y)
Λ = Matrix{T}(undef, (2 * d, 2 * d))
@inbounds for k2 in 1:d
Expand Down
2 changes: 1 addition & 1 deletion src/rules/mv_normal_mean_scale_precision/mean.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@

@rule MvNormalMeanScalePrecision(:μ, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, q_γ::Any) = begin
m_out_mean, m_out_cov = mean_cov(m_out)
return MvNormalMeanCovariance(m_out_mean, m_out_cov + inv(mean(q_γ)) * diageye(eltype(m_out), ndims(m_out)))
return MvNormalMeanCovariance(m_out_mean, m_out_cov + inv(mean(q_γ)) * diageye(samplefloattype(m_out), ndims(m_out)))
end
Loading