From da44b0509ca0875db3a737f46e6436852b32e995 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Mon, 17 Oct 2022 12:52:22 +0200 Subject: [PATCH 01/48] Add MAR node --- src/ReactiveMP.jl | 1 + src/nodes/autoregressive.jl | 2 +- src/nodes/mv_autoregressive.jl | 124 +++++++++++++++++++++++++++++++++ 3 files changed, 126 insertions(+), 1 deletion(-) create mode 100644 src/nodes/mv_autoregressive.jl diff --git a/src/ReactiveMP.jl b/src/ReactiveMP.jl index 5782fe66b..73f61d406 100644 --- a/src/ReactiveMP.jl +++ b/src/ReactiveMP.jl @@ -126,6 +126,7 @@ include("nodes/gamma_mixture.jl") include("nodes/dot_product.jl") include("nodes/transition.jl") include("nodes/autoregressive.jl") +include("nodes/mv_autoregressive.jl") include("nodes/bifm.jl") include("nodes/bifm_helper.jl") include("nodes/probit.jl") diff --git a/src/nodes/autoregressive.jl b/src/nodes/autoregressive.jl index 217676628..a696dd485 100644 --- a/src/nodes/autoregressive.jl +++ b/src/nodes/autoregressive.jl @@ -1,4 +1,4 @@ -export AR, Autoregressive, ARsafe, ARunsafe, ARMeta +export AR, Autoregressive, ARsafe, ARunsafe, ARMeta, ar_unit, ar_slice import LazyArrays import StatsFuns: log2π diff --git a/src/nodes/mv_autoregressive.jl b/src/nodes/mv_autoregressive.jl new file mode 100644 index 000000000..b7807fac3 --- /dev/null +++ b/src/nodes/mv_autoregressive.jl @@ -0,0 +1,124 @@ +export MAR, MvAutoregressive, MARMeta + +import LazyArrays +import StatsFuns: log2π + +struct MAR end + +const MvAutoregressive = MAR + +struct MARMeta + order::Int +end + +function MARMeta(order) + return MARMeta(order) +end + +getorder(meta::MARMeta) = meta.order + +@node MAR Stochastic [y, x, θ, Λ] + +default_meta(::Type{MAR}) = error("MvAutoregressive node requires meta flag explicitly specified") + +@average_energy AR ( + q_y_x::MultivariateNormalDistributionsFamily, + q_θ::MultivariateNormalDistributionsFamily, + q_Λ::Wishart, + meta::ARMeta +) = begin + mθ, Vθ = mean_cov(q_θ) + myx, Vyx = mean_cov(q_y_x) + mΛ = mean(q_Λ) + + order = getorder(meta) + + mx, Vx = ar_slice(getvform(meta), myx, (order+1):2order), ar_slice(getvform(meta), Vyx, (order+1):2order, (order+1):2order) + my1, Vy1 = first(myx), first(Vyx) + Vy1x = ar_slice(getvform(meta), Vyx, 1, order+1:2order) + + # Euivalento to AE = (-mean(log, q_γ) + log2π + mγ*(Vy1+my1^2 - 2*mθ'*(Vy1x + mx*my1) + tr(Vθ*Vx) + mx'*Vθ*mx + mθ'*(Vx + mx*mx')*mθ)) / 2 + AE = + ( + -mean(log, q_γ) + log2π + + mγ * ( + Vy1 + my1^2 - 2 * mθ' * (Vy1x + mx * my1) + mul_trace(Vθ, Vx) + dot(mx, Vθ, mx) + dot(mθ, Vx, mθ) + + abs2(dot(mθ, mx)) + ) + ) / 2 + + # correction + if getorder(meta) > 1 + AE += entropy(q_y_x) + idc = LazyArrays.Vcat(1, (order+1):2order) + myx_n = view(myx, idc) + Vyx_n = view(Vyx, idc, idc) + q_y_x = MvNormalMeanCovariance(myx_n, Vyx_n) + AE -= entropy(q_y_x) + end + + return AE +end + +# Helpers for AR rules + +## MAllocation-free AR Precision Matrix + +struct MARPrecisionMatrix{T} <: AbstractMatrix{T} + order :: Int + Λ :: T +end + +Base.size(precision::MARPrecisionMatrix) = (precision.order, precision.order) +Base.getindex(precision::MARPrecisionMatrix, i::Int, j::Int) = + (i === 1 && j === 1) ? precision.Λ : ((i === j) ? 
convert(eltype(precision), huge) : zero(eltype(precision))) + +Base.eltype(::Type{<:MARPrecisionMatrix{T}}) where {T} = T +Base.eltype(::MARPrecisionMatrix{T}) where {T} = T + +add_precision(matrix::AbstractMatrix, precision::MARPrecisionMatrix) = broadcast(+, matrix, precision) + +add_precision!(matrix::AbstractMatrix, precision::MARPrecisionMatrix) = broadcast!(+, matrix, precision) + +# function Base.broadcast!(::typeof(+), matrix::AbstractMatrix, precision::MARPrecisionMatrix) +# matrix[1, 1] += precision.γ +# for j in 2:first(size(matrix)) +# matrix[j, j] += convert(eltype(precision), huge) +# end +# return matrix +# end + +mar_precision(::Type{Multivariate}, order, Λ) = MARPrecisionMatrix(order, Λ) +mar_precision(::Type{Univariate}, order, γ) = γ + +## Allocation-free AR Transition matrix + +struct ARTransitionMatrix{T} <: AbstractMatrix{T} + order::Int + inv_γ::T + + function ARTransitionMatrix(order::Int, γ::T) where {T <: Real} + return new{T}(order, inv(γ)) + end +end + +Base.size(transition::ARTransitionMatrix) = (transition.order, transition.order) +Base.getindex(transition::ARTransitionMatrix, i::Int, j::Int) = + (i === 1 && j === 1) ? transition.inv_γ : zero(eltype(transition)) + +Base.eltype(::Type{<:ARTransitionMatrix{T}}) where {T} = T +Base.eltype(::ARTransitionMatrix{T}) where {T} = T + +add_transition(matrix::AbstractMatrix, transition::ARTransitionMatrix) = broadcast(+, matrix, transition) +add_transition(value::Real, transition::Real) = value + transition + +add_transition!(matrix::AbstractMatrix, transition::ARTransitionMatrix) = broadcast!(+, matrix, transition) +add_transition!(value::Real, transition::Real) = value + transition + +function Base.broadcast!(::typeof(+), matrix::AbstractMatrix, transition::ARTransitionMatrix) + matrix[1] += transition.inv_γ + return matrix +end + +ar_transition(::Type{Multivariate}, order, γ) = ARTransitionMatrix(order, γ) +ar_transition(::Type{Univariate}, order, γ) = inv(γ) From ef8e17ebd9ed745c1c3c8129ce0d9b5222954db5 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Thu, 20 Oct 2022 12:40:32 +0200 Subject: [PATCH 02/48] Add rules prototypes --- src/nodes/mv_autoregressive.jl | 85 +++++++++++++----------- src/rules/mv_autoregressive/helpers.jl | 0 src/rules/mv_autoregressive/lambda.jl | 0 src/rules/mv_autoregressive/marginals.jl | 0 src/rules/mv_autoregressive/theta.jl | 0 src/rules/mv_autoregressive/x.jl | 0 src/rules/mv_autoregressive/y.jl | 27 ++++++++ 7 files changed, 73 insertions(+), 39 deletions(-) create mode 100644 src/rules/mv_autoregressive/helpers.jl create mode 100644 src/rules/mv_autoregressive/lambda.jl create mode 100644 src/rules/mv_autoregressive/marginals.jl create mode 100644 src/rules/mv_autoregressive/theta.jl create mode 100644 src/rules/mv_autoregressive/x.jl create mode 100644 src/rules/mv_autoregressive/y.jl diff --git a/src/nodes/mv_autoregressive.jl b/src/nodes/mv_autoregressive.jl index b7807fac3..bcdac9022 100644 --- a/src/nodes/mv_autoregressive.jl +++ b/src/nodes/mv_autoregressive.jl @@ -8,14 +8,19 @@ struct MAR end const MvAutoregressive = MAR struct MARMeta - order::Int + order :: Int # order (lag) of MAR + ds :: Int # dimensionality of MAR process, i.e., the number of correlated AR processes end -function MARMeta(order) - return MARMeta(order) +function MARMeta(order, ds=2) + if ds < 2 + @error "ds parameter should be > 1. 
Use AR node if ds = 1" + end + return MARMeta(order, ds) end getorder(meta::MARMeta) = meta.order +getdimensionality(meta::MARMeta) = meta.ds @node MAR Stochastic [y, x, θ, Λ] @@ -27,44 +32,45 @@ default_meta(::Type{MAR}) = error("MvAutoregressive node requires meta flag expl q_Λ::Wishart, meta::ARMeta ) = begin - mθ, Vθ = mean_cov(q_θ) - myx, Vyx = mean_cov(q_y_x) - mΛ = mean(q_Λ) - - order = getorder(meta) - - mx, Vx = ar_slice(getvform(meta), myx, (order+1):2order), ar_slice(getvform(meta), Vyx, (order+1):2order, (order+1):2order) - my1, Vy1 = first(myx), first(Vyx) - Vy1x = ar_slice(getvform(meta), Vyx, 1, order+1:2order) - - # Euivalento to AE = (-mean(log, q_γ) + log2π + mγ*(Vy1+my1^2 - 2*mθ'*(Vy1x + mx*my1) + tr(Vθ*Vx) + mx'*Vθ*mx + mθ'*(Vx + mx*mx')*mθ)) / 2 - AE = - ( - -mean(log, q_γ) + log2π + - mγ * ( - Vy1 + my1^2 - 2 * mθ' * (Vy1x + mx * my1) + mul_trace(Vθ, Vx) + dot(mx, Vθ, mx) + dot(mθ, Vx, mθ) + - abs2(dot(mθ, mx)) - ) - ) / 2 - - # correction - if getorder(meta) > 1 - AE += entropy(q_y_x) - idc = LazyArrays.Vcat(1, (order+1):2order) - myx_n = view(myx, idc) - Vyx_n = view(Vyx, idc, idc) - q_y_x = MvNormalMeanCovariance(myx_n, Vyx_n) - AE -= entropy(q_y_x) - end - - return AE + # mθ, Vθ = mean_cov(q_θ) + # myx, Vyx = mean_cov(q_y_x) + # mΛ = mean(q_Λ) + + # order = getorder(meta) + + # mx, Vx = ar_slice(getvform(meta), myx, (order+1):2order), ar_slice(getvform(meta), Vyx, (order+1):2order, (order+1):2order) + # my1, Vy1 = first(myx), first(Vyx) + # Vy1x = ar_slice(getvform(meta), Vyx, 1, order+1:2order) + + # # Euivalento to AE = (-mean(log, q_γ) + log2π + mγ*(Vy1+my1^2 - 2*mθ'*(Vy1x + mx*my1) + tr(Vθ*Vx) + mx'*Vθ*mx + mθ'*(Vx + mx*mx')*mθ)) / 2 + # AE = + # ( + # -mean(log, q_γ) + log2π + + # mγ * ( + # Vy1 + my1^2 - 2 * mθ' * (Vy1x + mx * my1) + mul_trace(Vθ, Vx) + dot(mx, Vθ, mx) + dot(mθ, Vx, mθ) + + # abs2(dot(mθ, mx)) + # ) + # ) / 2 + + # # correction + # if getorder(meta) > 1 + # AE += entropy(q_y_x) + # idc = LazyArrays.Vcat(1, (order+1):2order) + # myx_n = view(myx, idc) + # Vyx_n = view(Vyx, idc, idc) + # q_y_x = MvNormalMeanCovariance(myx_n, Vyx_n) + # AE -= entropy(q_y_x) + # end + + # return AE + 0 end # Helpers for AR rules ## MAllocation-free AR Precision Matrix -struct MARPrecisionMatrix{T} <: AbstractMatrix{T} +struct MARPrecisionMatrix{T} <: AbstractMatrix{AbstractMatrix{T}} order :: Int Λ :: T end @@ -93,12 +99,13 @@ mar_precision(::Type{Univariate}, order, γ) = γ ## Allocation-free AR Transition matrix -struct ARTransitionMatrix{T} <: AbstractMatrix{T} +struct MARTransitionMatrix{T} <: AbstractMatrix{AbstractMatrix{T}} order::Int - inv_γ::T + ds ::Int + inv_Λ::T - function ARTransitionMatrix(order::Int, γ::T) where {T <: Real} - return new{T}(order, inv(γ)) + function ARTransitionMatrix(order::Int, ds::Int, inv_Λ::T) where {T <: AbstractMatrix} + return new{T}(order, ds, inv(inv_Λ)) end end diff --git a/src/rules/mv_autoregressive/helpers.jl b/src/rules/mv_autoregressive/helpers.jl new file mode 100644 index 000000000..e69de29bb diff --git a/src/rules/mv_autoregressive/lambda.jl b/src/rules/mv_autoregressive/lambda.jl new file mode 100644 index 000000000..e69de29bb diff --git a/src/rules/mv_autoregressive/marginals.jl b/src/rules/mv_autoregressive/marginals.jl new file mode 100644 index 000000000..e69de29bb diff --git a/src/rules/mv_autoregressive/theta.jl b/src/rules/mv_autoregressive/theta.jl new file mode 100644 index 000000000..e69de29bb diff --git a/src/rules/mv_autoregressive/x.jl b/src/rules/mv_autoregressive/x.jl new file mode 100644 index 
000000000..e69de29bb diff --git a/src/rules/mv_autoregressive/y.jl b/src/rules/mv_autoregressive/y.jl new file mode 100644 index 000000000..2ad339d7d --- /dev/null +++ b/src/rules/mv_autoregressive/y.jl @@ -0,0 +1,27 @@ + +@rule MAR(:y, Marginalisation) (m_x::NormalDistributionsFamily, q_θ::NormalDistributionsFamily, q_Λ::Any, meta::ARMeta) = +begin + mθ, Vθ = mean_cov(q_θ) + mx, Wx = mean_invcov(m_x) + + mΛ = mean(q_Λ) + + mA = as_companion_matrix(mθ) + mV = ar_transition(getvform(meta), getorder(meta), mγ) + + D = Wx + mγ * Vθ + C = mA * inv(D) + + my = C * Wx * mx + Vy = add_transition!(C * mA', mV) + + return convert(promote_variate_type(getvform(meta), NormalMeanVariance), my, Vy) +end + +@rule AR(:y, Marginalisation) (q_x::Any, q_θ::Any, q_γ::Any, meta::ARMeta) = begin +mA = as_companion_matrix(mean(q_θ)) + +mV = ar_transition(getvform(meta), getorder(meta), mean(q_γ)) + +return convert(promote_variate_type(getvform(meta), NormalMeanVariance), mA * mean(q_x), mV) +end From 35aba3568859d24e5f56bd70a040d71b3421ebe4 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 28 Oct 2022 14:03:18 +0200 Subject: [PATCH 03/48] Add rules for MAR --- src/nodes/mv_autoregressive.jl | 152 ++++++++++------------- src/rules/mv_autoregressive/a.jl | 26 ++++ src/rules/mv_autoregressive/lambda.jl | 37 ++++++ src/rules/mv_autoregressive/marginals.jl | 56 +++++++++ src/rules/mv_autoregressive/theta.jl | 0 src/rules/mv_autoregressive/x.jl | 29 +++++ src/rules/mv_autoregressive/y.jl | 34 ++--- src/rules/prototypes.jl | 6 + 8 files changed, 240 insertions(+), 100 deletions(-) create mode 100644 src/rules/mv_autoregressive/a.jl delete mode 100644 src/rules/mv_autoregressive/theta.jl diff --git a/src/nodes/mv_autoregressive.jl b/src/nodes/mv_autoregressive.jl index bcdac9022..c61f0560b 100644 --- a/src/nodes/mv_autoregressive.jl +++ b/src/nodes/mv_autoregressive.jl @@ -1,4 +1,4 @@ -export MAR, MvAutoregressive, MARMeta +export MAR, MvAutoregressive, MARMeta, mar_transition, mar_shift import LazyArrays import StatsFuns: log2π @@ -26,106 +26,92 @@ getdimensionality(meta::MARMeta) = meta.ds default_meta(::Type{MAR}) = error("MvAutoregressive node requires meta flag explicitly specified") -@average_energy AR ( +@average_energy MAR ( q_y_x::MultivariateNormalDistributionsFamily, - q_θ::MultivariateNormalDistributionsFamily, + q_a::MultivariateNormalDistributionsFamily, q_Λ::Wishart, - meta::ARMeta + meta::MARMeta ) = begin - # mθ, Vθ = mean_cov(q_θ) - # myx, Vyx = mean_cov(q_y_x) - # mΛ = mean(q_Λ) + ma, Va = mean_cov(q_a) + myx, Vyx = mean_cov(q_y_x) + mΛ = mean(q_Λ) - # order = getorder(meta) + order, ds = getorder(meta), getdimensionality(meta) + F = Multivariate + dim = order*ds + n = div(ndims(q_y_x), 2) - # mx, Vx = ar_slice(getvform(meta), myx, (order+1):2order), ar_slice(getvform(meta), Vyx, (order+1):2order, (order+1):2order) - # my1, Vy1 = first(myx), first(Vyx) - # Vy1x = ar_slice(getvform(meta), Vyx, 1, order+1:2order) - # # Euivalento to AE = (-mean(log, q_γ) + log2π + mγ*(Vy1+my1^2 - 2*mθ'*(Vy1x + mx*my1) + tr(Vθ*Vx) + mx'*Vθ*mx + mθ'*(Vx + mx*mx')*mθ)) / 2 - # AE = - # ( - # -mean(log, q_γ) + log2π + - # mγ * ( - # Vy1 + my1^2 - 2 * mθ' * (Vy1x + mx * my1) + mul_trace(Vθ, Vx) + dot(mx, Vθ, mx) + dot(mθ, Vx, mθ) + - # abs2(dot(mθ, mx)) - # ) - # ) / 2 - - # # correction - # if getorder(meta) > 1 - # AE += entropy(q_y_x) - # idc = LazyArrays.Vcat(1, (order+1):2order) - # myx_n = view(myx, idc) - # Vyx_n = view(Vyx, idc, idc) - # q_y_x = MvNormalMeanCovariance(myx_n, Vyx_n) - # AE -= entropy(q_y_x) 
- # end - - # return AE - 0 -end - -# Helpers for AR rules + ma, Va = mean_cov(q_a) + mA = mar_companion_matrix(order, ds, ma)[1:order, 1:dim] -## MAllocation-free AR Precision Matrix + mx, Vx = ar_slice(F, myx, (dim+1):2dim), ar_slice(F, Vyx, (dim+1):2dim, (dim+1):2dim) + my1, Vy1 = myx[1:ds], Vyx[1:ds, 1:ds] + Vy1x = ar_slice(F, Vyx, 1:ds, dim+1:2dim) -struct MARPrecisionMatrix{T} <: AbstractMatrix{AbstractMatrix{T}} - order :: Int - Λ :: T -end - -Base.size(precision::MARPrecisionMatrix) = (precision.order, precision.order) -Base.getindex(precision::MARPrecisionMatrix, i::Int, j::Int) = - (i === 1 && j === 1) ? precision.Λ : ((i === j) ? convert(eltype(precision), huge) : zero(eltype(precision))) + # this should be inside MARMeta + es = [uvector(order, i) for i in 1:order] + Fs = [mask_mar(order, ds, i) for i in 1:order] -Base.eltype(::Type{<:MARPrecisionMatrix{T}}) where {T} = T -Base.eltype(::MARPrecisionMatrix{T}) where {T} = T + # # Euivalento to AE = (-mean(log, q_γ) + log2π + mγ*(Vy1+my1^2 - 2*mθ'*(Vy1x + mx*my1) + tr(Vθ*Vx) + mx'*Vθ*mx + mθ'*(Vx + mx*mx')*mθ)) / 2 + g₁ = my1'*mΛ*my1 + tr(Vy1*mΛ) + g₂ = mx'*mA'*mΛ*my1 + tr(Vy1x*mA'*mΛ) + g₃ = g₂ + G = sum(sum(Fs[i]*(ma*es[i]'*mΛ*es[j]*ma' + Va*es[i]'*mΛ*es[j])*Fs[j]' for i in 1:order) for j in 1:order) + g₄ = mx'*G*mx + tr(Vx*G) + AE = -mean(logdet, q_Λ) + n/2*log2π + 0.5 + g₁ + g₂ + g₃ + g₄ + + if order > 1 + AE += entropy(q_y_x) + @show idc = LazyArrays.Vcat(1:order, (dim+1):2dim) + myx_n = view(myx, idc) + Vyx_n = view(Vyx, idc, idc) + q_y_x = MvNormalMeanCovariance(myx_n, Vyx_n) + AE -= entropy(q_y_x) + end -add_precision(matrix::AbstractMatrix, precision::MARPrecisionMatrix) = broadcast(+, matrix, precision) + return AE +end -add_precision!(matrix::AbstractMatrix, precision::MARPrecisionMatrix) = broadcast!(+, matrix, precision) +# Helpers for AR rules -# function Base.broadcast!(::typeof(+), matrix::AbstractMatrix, precision::MARPrecisionMatrix) -# matrix[1, 1] += precision.γ -# for j in 2:first(size(matrix)) -# matrix[j, j] += convert(eltype(precision), huge) -# end -# return matrix -# end -mar_precision(::Type{Multivariate}, order, Λ) = MARPrecisionMatrix(order, Λ) -mar_precision(::Type{Univariate}, order, γ) = γ +function mask_mar(order, ds, index) + theta_len = order * order * ds + F = zeros(order * ds, theta_len) + F[1:order, order*index-1:order*index] = diageye(order) + F[ds+1:end, order*index+ds+1:order*index+ds+order] = diageye(order) + return F +end -## Allocation-free AR Transition matrix +function mar_transition(order, Λ) + dim = size(Λ, 1) + W = 1e12*diageye(dim*order) + W[1:dim, 1:dim] = Λ + return W +end -struct MARTransitionMatrix{T} <: AbstractMatrix{AbstractMatrix{T}} - order::Int - ds ::Int - inv_Λ::T - function ARTransitionMatrix(order::Int, ds::Int, inv_Λ::T) where {T <: AbstractMatrix} - return new{T}(order, ds, inv(inv_Λ)) +function mar_shift(order, ds) + dim = order*ds + S = diageye(dim) + for i in dim:-1:ds+1 + S[i,:] = S[i-ds, :] end + S[1:ds, :] = zeros(ds, dim) + return S end -Base.size(transition::ARTransitionMatrix) = (transition.order, transition.order) -Base.getindex(transition::ARTransitionMatrix, i::Int, j::Int) = - (i === 1 && j === 1) ? 
transition.inv_γ : zero(eltype(transition)) - -Base.eltype(::Type{<:ARTransitionMatrix{T}}) where {T} = T -Base.eltype(::ARTransitionMatrix{T}) where {T} = T - -add_transition(matrix::AbstractMatrix, transition::ARTransitionMatrix) = broadcast(+, matrix, transition) -add_transition(value::Real, transition::Real) = value + transition - -add_transition!(matrix::AbstractMatrix, transition::ARTransitionMatrix) = broadcast!(+, matrix, transition) -add_transition!(value::Real, transition::Real) = value + transition - -function Base.broadcast!(::typeof(+), matrix::AbstractMatrix, transition::ARTransitionMatrix) - matrix[1] += transition.inv_γ - return matrix +function uvector(dim, pos=1) + u = zeros(dim) + u[pos] = 1 + return dim == 1 ? u[pos] : u end -ar_transition(::Type{Multivariate}, order, γ) = ARTransitionMatrix(order, γ) -ar_transition(::Type{Univariate}, order, γ) = inv(γ) +function mar_companion_matrix(order, ds, a) + dim = order*ds + S = mar_shift(order, ds) + es = [uvector(dim, i) for i in 1:order] + Fs = [mask_mar(order, ds, i) for i in 1:order] + return S .+ sum(es[i]*a'*Fs[i]' for i in 1:order) +end \ No newline at end of file diff --git a/src/rules/mv_autoregressive/a.jl b/src/rules/mv_autoregressive/a.jl new file mode 100644 index 000000000..794a5dc2d --- /dev/null +++ b/src/rules/mv_autoregressive/a.jl @@ -0,0 +1,26 @@ + +@rule MAR(:a, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin + + order, ds = getorder(meta), getdimensionality(meta) + F = Multivariate + + dim = order*ds + + myx, Vyx = mean_cov(q_y_x) + my, Vy = ar_slice(F, myx, 1:dim), ar_slice(F, Vyx, 1:dim, 1:dim) + mx, Vx = ar_slice(F, myx, (dim+1):2dim), ar_slice(F, Vyx, (dim+1):2dim, (dim+1):2dim) + Vyx = ar_slice(F, Vyx, (dim+1):2dim, 1:dim) + mΛ = mean(q_Λ) + mW = mar_transition(order, mΛ) + mW + # this should be inside MARMeta + es = [uvector(dim, i) for i in 1:order] + Fs = [mask_mar(order, ds, i) for i in 1:order] + # @show Iterators.product(transpose.(es), mW, es) + # @show sum(prod, Iterators.product(transpose.(es), mW, es)) + # ∏ = Iterators.product(transpose.(es), mW, es, transpose.(Fs), (Vx + mx*mx'), Fs) + + D = sum(sum(es[i]'*mW*es[j]*Fs[i]'*(Vx + mx*mx')*Fs[j] for i in 1:order) for j in 1:order) + z = sum(Fs[i]'*(Vyx + my*mx')*mW*es[i] for i in 1:order) + return MvNormalMeanCovariance(inv(D)*z, inv(D)) +end \ No newline at end of file diff --git a/src/rules/mv_autoregressive/lambda.jl b/src/rules/mv_autoregressive/lambda.jl index e69de29bb..fd70e5594 100644 --- a/src/rules/mv_autoregressive/lambda.jl +++ b/src/rules/mv_autoregressive/lambda.jl @@ -0,0 +1,37 @@ +@rule MAR(:Λ, Marginalisation) ( + q_y_x::MultivariateNormalDistributionsFamily, + q_a::MultivariateNormalDistributionsFamily, + meta::MARMeta +) = begin + order, ds = getorder(meta), getdimensionality(meta) + F = Multivariate + dim = order*ds + + n = div(ndims(q_y_x), 2) + + y_x_mean, y_x_cov = mean_cov(q_y_x) + ma, Va = mean_cov(q_a) + + mA = mar_companion_matrix(order, ds, ma) + + myx, Vyx = mean_cov(q_y_x) + my, Vy = ar_slice(F, myx, 1:dim), ar_slice(F, Vyx, 1:dim, 1:dim) + mx, Vx = ar_slice(F, myx, (dim+1):2dim), ar_slice(F, Vyx, (dim+1):2dim, (dim+1):2dim) + Vyx = ar_slice(F, Vyx, (dim+1):2dim, 1:dim) + + # this should be inside MARMeta + es = [uvector(dim, i) for i in 1:order] + Fs = [mask_mar(order, ds, i) for i in 1:order] + + vmx = (Vx + mx*mx') + S = mar_shift(order, ds) + G₁ = S*vmx*S' + G₂ = sum(es[i]*ma'Fs[i]'vmx for i in 1:order)*S' + G₃ = transpose(G₂) + G₄ = 
sum(sum(es[i]*ma'Fs[i]'*vmx*Fs[j]*ma*es[j]' + es[i]*tr(Va*Fs[i]'*vmx*Fs[j])*es[j]' for i in 1:order) for j in 1:order) + G = G₁ + G₂ + G₃ + G₄ + + @show Δ = G + Vy + my*my' - (Vyx + my*mx')*mA' - mA*(Vyx'+ mx*my') + + return WishartMessage(n-2, Δ) +end \ No newline at end of file diff --git a/src/rules/mv_autoregressive/marginals.jl b/src/rules/mv_autoregressive/marginals.jl index e69de29bb..c5cee4845 100644 --- a/src/rules/mv_autoregressive/marginals.jl +++ b/src/rules/mv_autoregressive/marginals.jl @@ -0,0 +1,56 @@ + +@marginalrule MAR(:y_x) ( + m_y::MultivariateNormalDistributionsFamily, + m_x::MultivariateNormalDistributionsFamily, + q_a::MultivariateNormalDistributionsFamily, + q_Λ::Any, + meta::MARMeta +) = begin + return ar_y_x_marginal(m_y, m_x, q_a, q_Λ, meta) +end + +function ar_y_x_marginal( + m_y::MultivariateNormalDistributionsFamily, + m_x::MultivariateNormalDistributionsFamily, + q_a::MultivariateNormalDistributionsFamily, + q_Λ::Any, + meta::MARMeta) + + order, ds = getorder(meta), getdimensionality(meta) + F = Multivariate + dim = order*ds + + + ma, Va = mean_cov(q_a) + mΛ = mean(q_Λ) + + mA = mar_companion_matrix(order, ds, ma) + mW = mar_transition(getorder(meta), mΛ) + + b_my, b_Vy = mean_cov(m_y) + f_mx, f_Vx = mean_cov(m_x) + + inv_b_Vy = cholinv(b_Vy) + inv_f_Vx = cholinv(f_Vx) + + # this should be inside MARMeta + es = [uvector(dim, i) for i in 1:order] + Fs = [mask_mar(order, ds, i) for i in 1:order] + + Ξ = inv_f_Vx + sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:order) for j in 1:order) + + W_11 = inv_b_Vy + mW + + # negate_inplace!(mW * mA) + W_12 = -(mW * mA) + + # Equivalent to + W_21 = (-mA' * mW) + + W_22 = Ξ + mA' * mW * mA + + W = [W_11 W_12; W_21 W_22] + ξ = [inv_b_Vy * b_my; inv_f_Vx * f_mx] + + return MvNormalWeightedMeanPrecision(ξ, W) +end diff --git a/src/rules/mv_autoregressive/theta.jl b/src/rules/mv_autoregressive/theta.jl deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/rules/mv_autoregressive/x.jl b/src/rules/mv_autoregressive/x.jl index e69de29bb..cc0b66d9b 100644 --- a/src/rules/mv_autoregressive/x.jl +++ b/src/rules/mv_autoregressive/x.jl @@ -0,0 +1,29 @@ + +@rule MAR(:x, Marginalisation) (m_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = +begin + ma, Va = mean_cov(q_a) + my, Vy = mean_cov(m_y) + + mΛ = mean(q_Λ) + + order, ds = getorder(meta), getdimensionality(meta) + dim = order*ds + + mA = mar_companion_matrix(order, ds, ma) + mW = mar_transition(getorder(meta), mΛ) + + # this should be inside MARMeta + es = [uvector(dim, i) for i in 1:order] + Fs = [mask_mar(order, ds, i) for i in 1:order] + + # ∏ = Iterators.product(transpose.(es), mW, es, Fs, Va, transpose.(Fs)) + Σ = sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:order) for j in 1:order) + + Ξ = mA'*inv(Vy + inv(mW))*mA + inv(Σ) + z = mA'*inv(Vy + inv(mW))*my + + mx = inv(Ξ)*z + Vx = inv(Ξ) + + return MvNormalMeanCovariance(mx, Vx) +end \ No newline at end of file diff --git a/src/rules/mv_autoregressive/y.jl b/src/rules/mv_autoregressive/y.jl index 2ad339d7d..fbdc3047d 100644 --- a/src/rules/mv_autoregressive/y.jl +++ b/src/rules/mv_autoregressive/y.jl @@ -1,27 +1,27 @@ - -@rule MAR(:y, Marginalisation) (m_x::NormalDistributionsFamily, q_θ::NormalDistributionsFamily, q_Λ::Any, meta::ARMeta) = +@rule MAR(:y, Marginalisation) (m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin - mθ, Vθ = mean_cov(q_θ) + ma, Va = 
mean_cov(q_a) mx, Wx = mean_invcov(m_x) mΛ = mean(q_Λ) - mA = as_companion_matrix(mθ) - mV = ar_transition(getvform(meta), getorder(meta), mγ) - - D = Wx + mγ * Vθ - C = mA * inv(D) + order, ds = getorder(meta), getdimensionality(meta) - my = C * Wx * mx - Vy = add_transition!(C * mA', mV) + mA = mar_companion_matrix(order, ds, ma) + mW = mar_transition(getorder(meta), mΛ) + dim = order*ds + # this should be inside MARMeta + es = [uvector(dim, i) for i in 1:order] + Fs = [mask_mar(order, ds, i) for i in 1:order] + + Σ = sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:order) for j in 1:order) - return convert(promote_variate_type(getvform(meta), NormalMeanVariance), my, Vy) -end -@rule AR(:y, Marginalisation) (q_x::Any, q_θ::Any, q_γ::Any, meta::ARMeta) = begin -mA = as_companion_matrix(mean(q_θ)) + Ξ = inv(Σ) + Wx + z = Wx*mx -mV = ar_transition(getvform(meta), getorder(meta), mean(q_γ)) + Vy = mA*inv(Ξ)*mA' + inv(Wx) + my = mA*inv(Ξ)*z -return convert(promote_variate_type(getvform(meta), NormalMeanVariance), mA * mean(q_x), mV) -end + return MvNormalMeanCovariance(my, Vy) +end \ No newline at end of file diff --git a/src/rules/prototypes.jl b/src/rules/prototypes.jl index 1e7e46a00..c33fbd77e 100644 --- a/src/rules/prototypes.jl +++ b/src/rules/prototypes.jl @@ -115,6 +115,12 @@ include("autoregressive/theta.jl") include("autoregressive/gamma.jl") include("autoregressive/marginals.jl") +include("mv_autoregressive/y.jl") +include("mv_autoregressive/x.jl") +include("mv_autoregressive/a.jl") +include("mv_autoregressive/lambda.jl") +include("mv_autoregressive/marginals.jl") + include("probit/marginals.jl") include("probit/in.jl") include("probit/out.jl") From 3462700819488be74218ce5605209a44dc729087 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Mon, 31 Oct 2022 09:28:05 +0100 Subject: [PATCH 04/48] Update MAR --- src/nodes/mv_autoregressive.jl | 6 +++--- src/rules/mv_autoregressive/lambda.jl | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/nodes/mv_autoregressive.jl b/src/nodes/mv_autoregressive.jl index c61f0560b..f0efb8d44 100644 --- a/src/nodes/mv_autoregressive.jl +++ b/src/nodes/mv_autoregressive.jl @@ -22,7 +22,7 @@ end getorder(meta::MARMeta) = meta.order getdimensionality(meta::MARMeta) = meta.ds -@node MAR Stochastic [y, x, θ, Λ] +@node MAR Stochastic [y, x, a, Λ] default_meta(::Type{MAR}) = error("MvAutoregressive node requires meta flag explicitly specified") @@ -59,11 +59,11 @@ default_meta(::Type{MAR}) = error("MvAutoregressive node requires meta flag expl g₃ = g₂ G = sum(sum(Fs[i]*(ma*es[i]'*mΛ*es[j]*ma' + Va*es[i]'*mΛ*es[j])*Fs[j]' for i in 1:order) for j in 1:order) g₄ = mx'*G*mx + tr(Vx*G) - AE = -mean(logdet, q_Λ) + n/2*log2π + 0.5 + g₁ + g₂ + g₃ + g₄ + AE = mean(logdet, q_Λ) - n/2*log2π + 0.5 + g₁ + g₂ + g₃ + g₄ if order > 1 AE += entropy(q_y_x) - @show idc = LazyArrays.Vcat(1:order, (dim+1):2dim) + idc = LazyArrays.Vcat(1:order, (dim+1):2dim) myx_n = view(myx, idc) Vyx_n = view(Vyx, idc, idc) q_y_x = MvNormalMeanCovariance(myx_n, Vyx_n) diff --git a/src/rules/mv_autoregressive/lambda.jl b/src/rules/mv_autoregressive/lambda.jl index fd70e5594..dafd6f435 100644 --- a/src/rules/mv_autoregressive/lambda.jl +++ b/src/rules/mv_autoregressive/lambda.jl @@ -31,7 +31,7 @@ G₄ = sum(sum(es[i]*ma'Fs[i]'*vmx*Fs[j]*ma*es[j]' + es[i]*tr(Va*Fs[i]'*vmx*Fs[j])*es[j]' for i in 1:order) for j in 1:order) G = G₁ + G₂ + G₃ + G₄ - @show Δ = G + Vy + my*my' - (Vyx + my*mx')*mA' - mA*(Vyx'+ mx*my') - - return WishartMessage(n-2, Δ) + Δ = G + Vy + my*my' - 
(Vyx + my*mx')*mA' - mA*(Vyx'+ mx*my') + # TODO check for n + return WishartMessage(n+2, Δ[1:order, 1:order]) end \ No newline at end of file From 91179d1ce6f57a7f6194044c506cbf9e2678b312 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Mon, 28 Nov 2022 02:24:05 -0600 Subject: [PATCH 05/48] Update rules --- src/rules/mv_autoregressive/a.jl | 6 +++--- src/rules/mv_autoregressive/lambda.jl | 15 ++++++++------- src/rules/mv_autoregressive/x.jl | 4 ++-- src/rules/mv_autoregressive/y.jl | 4 ++-- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/rules/mv_autoregressive/a.jl b/src/rules/mv_autoregressive/a.jl index 794a5dc2d..b14d35468 100644 --- a/src/rules/mv_autoregressive/a.jl +++ b/src/rules/mv_autoregressive/a.jl @@ -12,7 +12,7 @@ Vyx = ar_slice(F, Vyx, (dim+1):2dim, 1:dim) mΛ = mean(q_Λ) mW = mar_transition(order, mΛ) - mW + # this should be inside MARMeta es = [uvector(dim, i) for i in 1:order] Fs = [mask_mar(order, ds, i) for i in 1:order] @@ -20,7 +20,7 @@ # @show sum(prod, Iterators.product(transpose.(es), mW, es)) # ∏ = Iterators.product(transpose.(es), mW, es, transpose.(Fs), (Vx + mx*mx'), Fs) - D = sum(sum(es[i]'*mW*es[j]*Fs[i]'*(Vx + mx*mx')*Fs[j] for i in 1:order) for j in 1:order) - z = sum(Fs[i]'*(Vyx + my*mx')*mW*es[i] for i in 1:order) + D = sum(sum(es[i]'*mW*es[j]*Fs[i]'*(mx*mx' + Vx)*Fs[j] for i in 1:order) for j in 1:order) + z = sum(Fs[i]'*(mx*my'+Vyx')*mW*es[i] for i in 1:order) return MvNormalMeanCovariance(inv(D)*z, inv(D)) end \ No newline at end of file diff --git a/src/rules/mv_autoregressive/lambda.jl b/src/rules/mv_autoregressive/lambda.jl index dafd6f435..62a0b633f 100644 --- a/src/rules/mv_autoregressive/lambda.jl +++ b/src/rules/mv_autoregressive/lambda.jl @@ -23,15 +23,16 @@ es = [uvector(dim, i) for i in 1:order] Fs = [mask_mar(order, ds, i) for i in 1:order] - vmx = (Vx + mx*mx') S = mar_shift(order, ds) - G₁ = S*vmx*S' - G₂ = sum(es[i]*ma'Fs[i]'vmx for i in 1:order)*S' + G₁ = S*Vx*S' + G₂ = sum(S*Vx*Fs[i]*ma*es[i]' for i in 1:order) G₃ = transpose(G₂) - G₄ = sum(sum(es[i]*ma'Fs[i]'*vmx*Fs[j]*ma*es[j]' + es[i]*tr(Va*Fs[i]'*vmx*Fs[j])*es[j]' for i in 1:order) for j in 1:order) - G = G₁ + G₂ + G₃ + G₄ + G₄ = sum(sum(es[i]*ma'*Fs[i]'*Vx*Fs[j]*ma*es[j]' for i in 1:order) for j in 1:order) + G₅ = sum(sum(es[i]*mx'*Fs[j]*Va*Fs[i]'*mx*es[j]' for i in 1:order) for j in 1:order) + G₆ = sum(sum(es[i]*tr(Va*Fs[i]'*Vx*Fs[j])*es[j]' for i in 1:order) for j in 1:order) + G = G₁ + G₂ + G₃ + G₄ + G₅ + G₆ - Δ = G + Vy + my*my' - (Vyx + my*mx')*mA' - mA*(Vyx'+ mx*my') - # TODO check for n + Δ = (my - mA*mx)*(my - mA*mx)' - mA*Vyx' - Vyx*mA + S*Vx*S' + G + return WishartMessage(n+2, Δ[1:order, 1:order]) end \ No newline at end of file diff --git a/src/rules/mv_autoregressive/x.jl b/src/rules/mv_autoregressive/x.jl index cc0b66d9b..bbd54c936 100644 --- a/src/rules/mv_autoregressive/x.jl +++ b/src/rules/mv_autoregressive/x.jl @@ -17,9 +17,9 @@ begin Fs = [mask_mar(order, ds, i) for i in 1:order] # ∏ = Iterators.product(transpose.(es), mW, es, Fs, Va, transpose.(Fs)) - Σ = sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:order) for j in 1:order) + Λ = sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:order) for j in 1:order) - Ξ = mA'*inv(Vy + inv(mW))*mA + inv(Σ) + Ξ = mA*inv(inv(mW) + Vy)*mA' + Λ z = mA'*inv(Vy + inv(mW))*my mx = inv(Ξ)*z diff --git a/src/rules/mv_autoregressive/y.jl b/src/rules/mv_autoregressive/y.jl index fbdc3047d..d07bd91fe 100644 --- a/src/rules/mv_autoregressive/y.jl +++ b/src/rules/mv_autoregressive/y.jl @@ -14,10 +14,10 @@ 
begin es = [uvector(dim, i) for i in 1:order] Fs = [mask_mar(order, ds, i) for i in 1:order] - Σ = sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:order) for j in 1:order) + Λ = sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:order) for j in 1:order) - Ξ = inv(Σ) + Wx + Ξ = Λ + Wx z = Wx*mx Vy = mA*inv(Ξ)*mA' + inv(Wx) From 6795aff07c6e74604ec50a901a55f9088d7dc549 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 30 Dec 2022 19:11:52 +0100 Subject: [PATCH 06/48] Update rules --- src/rules/mv_autoregressive/lambda.jl | 23 +++++++++++++++-------- src/rules/mv_autoregressive/y.jl | 2 +- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/rules/mv_autoregressive/lambda.jl b/src/rules/mv_autoregressive/lambda.jl index 62a0b633f..95432aed7 100644 --- a/src/rules/mv_autoregressive/lambda.jl +++ b/src/rules/mv_autoregressive/lambda.jl @@ -24,15 +24,22 @@ Fs = [mask_mar(order, ds, i) for i in 1:order] S = mar_shift(order, ds) - G₁ = S*Vx*S' - G₂ = sum(S*Vx*Fs[i]*ma*es[i]' for i in 1:order) + # G₁ = S*Vx*S' + # G₂ = sum(S*Vx*Fs[i]*ma*es[i]' for i in 1:order) + # G₃ = transpose(G₂) + # G₄ = sum(sum(es[i]*ma'*Fs[i]'*Vx*Fs[j]*ma*es[j]' for i in 1:order) for j in 1:order) + # G₅ = sum(sum(es[i]*mx'*Fs[j]*Va*Fs[i]'*mx*es[j]' for i in 1:order) for j in 1:order) + # G₆ = sum(sum(es[i]*tr(Va*Fs[i]'*Vx*Fs[j])*es[j]' for i in 1:order) for j in 1:order) + G₁ = Vy[1:order, 1:order] + G₂ = (my*mx'*mA')[1:order, 1:order] G₃ = transpose(G₂) - G₄ = sum(sum(es[i]*ma'*Fs[i]'*Vx*Fs[j]*ma*es[j]' for i in 1:order) for j in 1:order) - G₅ = sum(sum(es[i]*mx'*Fs[j]*Va*Fs[i]'*mx*es[j]' for i in 1:order) for j in 1:order) - G₆ = sum(sum(es[i]*tr(Va*Fs[i]'*Vx*Fs[j])*es[j]' for i in 1:order) for j in 1:order) - G = G₁ + G₂ + G₃ + G₄ + G₅ + G₆ + Ex_xx = mx*mx' + Vx + G₅ = sum(sum(es[i]*ma'*Fs[j]'Ex_xx*Fs[i]*ma*es[j]' for i in 1:order) for j in 1:order)[1:order, 1:order] + G₆ = sum(sum(es[i]*tr(Va*Fs[i]'*Ex_xx*Fs[j])*es[j]' for i in 1:order) for j in 1:order)[1:order, 1:order] + Δ = G₁ + G₂ + G₃ + G₅ + G₆ + # G = G₁ + G₂ + G₃ + G₄ + G₅ + G₆ - Δ = (my - mA*mx)*(my - mA*mx)' - mA*Vyx' - Vyx*mA + S*Vx*S' + G + # Δ = (my - mA*mx)*(my - mA*mx)' - mA*Vyx' - Vyx*mA + S*Vx*S' + G - return WishartMessage(n+2, Δ[1:order, 1:order]) + return WishartMessage(n+2, Δ) end \ No newline at end of file diff --git a/src/rules/mv_autoregressive/y.jl b/src/rules/mv_autoregressive/y.jl index d07bd91fe..b46dd141a 100644 --- a/src/rules/mv_autoregressive/y.jl +++ b/src/rules/mv_autoregressive/y.jl @@ -20,7 +20,7 @@ begin Ξ = Λ + Wx z = Wx*mx - Vy = mA*inv(Ξ)*mA' + inv(Wx) + Vy = mA*inv(Ξ)*mA' + inv(mW) my = mA*inv(Ξ)*z return MvNormalMeanCovariance(my, Vy) From bd7be22f78bb48f0455827cbacf8394c97ba1f73 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Mon, 2 Jan 2023 18:12:37 +0100 Subject: [PATCH 07/48] WIP: Update mask MAR function --- src/nodes/mv_autoregressive.jl | 51 +++++++++++++++++++++------ src/rules/mv_autoregressive/lambda.jl | 11 +----- src/rules/mv_autoregressive/x.jl | 1 - 3 files changed, 41 insertions(+), 22 deletions(-) diff --git a/src/nodes/mv_autoregressive.jl b/src/nodes/mv_autoregressive.jl index f0efb8d44..8608cf3cd 100644 --- a/src/nodes/mv_autoregressive.jl +++ b/src/nodes/mv_autoregressive.jl @@ -1,6 +1,6 @@ export MAR, MvAutoregressive, MARMeta, mar_transition, mar_shift -import LazyArrays +import LazyArrays, BlockArrays import StatsFuns: log2π struct MAR end @@ -55,13 +55,14 @@ default_meta(::Type{MAR}) = error("MvAutoregressive node requires meta flag expl # # Euivalento to AE = (-mean(log, 
q_γ) + log2π + mγ*(Vy1+my1^2 - 2*mθ'*(Vy1x + mx*my1) + tr(Vθ*Vx) + mx'*Vθ*mx + mθ'*(Vx + mx*mx')*mθ)) / 2 g₁ = my1'*mΛ*my1 + tr(Vy1*mΛ) - g₂ = mx'*mA'*mΛ*my1 + tr(Vy1x*mA'*mΛ) - g₃ = g₂ - G = sum(sum(Fs[i]*(ma*es[i]'*mΛ*es[j]*ma' + Va*es[i]'*mΛ*es[j])*Fs[j]' for i in 1:order) for j in 1:order) + g₂ = -mx'*mA'*mΛ*my1 + tr(Vy1x*mA'*mΛ) + g₃ = -g₂ + G = sum(sum(es[i]'*mΛ*es[j]*Fs[i]*(ma*ma' + Va)*Fs[j]' for i in 1:order) for j in 1:order) g₄ = mx'*G*mx + tr(Vx*G) - AE = mean(logdet, q_Λ) - n/2*log2π + 0.5 + g₁ + g₂ + g₃ + g₄ + AE = n/2*log2π - 0.5*mean(logdet, q_Λ) + 0.5*(g₁ + g₂ + g₃ + g₄) if order > 1 + mean(q_y_x) AE += entropy(q_y_x) idc = LazyArrays.Vcat(1:order, (dim+1):2dim) myx_n = view(myx, idc) @@ -76,12 +77,40 @@ end # Helpers for AR rules -function mask_mar(order, ds, index) - theta_len = order * order * ds - F = zeros(order * ds, theta_len) - F[1:order, order*index-1:order*index] = diageye(order) - F[ds+1:end, order*index+ds+1:order*index+ds+order] = diageye(order) - return F +# p, d, i +# function mask_mar(order, ds, index) +# Frows = order * ds +# Fcols = ds * Frows +# F = zeros(Frows, Fcols) +# FB = BlockArray(F, ) +# for k in 1:ds*order +# for j in 1:ds*order^2 +# if j == ds*(index+(k-1)*ds) +# F[k, j] = 1.0 +# else +# F[k, j] = .0 +# end +# end +# end +# # F[1:order, ds*index-1:ds*index] = diageye(ds) +# # F[ds+1:end, ds*index+ds+1:ds*index+ds+ds] = diageye(ds) +# @show F +# return F +# end + +function mask_mar(p, d, index) + F = zeros(d*p, d*d*p) + rows = repeat([d], p) + cols = repeat([d], d*p) + FB = BlockArrays.BlockArray(F, rows, cols) + for k in 1:p + for j in 1:d*p + if j == index + (k-1)*d + view(FB, BlockArrays.Block(k, j)) .= diageye(d) + end + end + end + return Matrix(FB) end function mar_transition(order, Λ) diff --git a/src/rules/mv_autoregressive/lambda.jl b/src/rules/mv_autoregressive/lambda.jl index 95432aed7..4cbb5017a 100644 --- a/src/rules/mv_autoregressive/lambda.jl +++ b/src/rules/mv_autoregressive/lambda.jl @@ -19,17 +19,11 @@ mx, Vx = ar_slice(F, myx, (dim+1):2dim), ar_slice(F, Vyx, (dim+1):2dim, (dim+1):2dim) Vyx = ar_slice(F, Vyx, (dim+1):2dim, 1:dim) - # this should be inside MARMeta es = [uvector(dim, i) for i in 1:order] Fs = [mask_mar(order, ds, i) for i in 1:order] S = mar_shift(order, ds) - # G₁ = S*Vx*S' - # G₂ = sum(S*Vx*Fs[i]*ma*es[i]' for i in 1:order) - # G₃ = transpose(G₂) - # G₄ = sum(sum(es[i]*ma'*Fs[i]'*Vx*Fs[j]*ma*es[j]' for i in 1:order) for j in 1:order) - # G₅ = sum(sum(es[i]*mx'*Fs[j]*Va*Fs[i]'*mx*es[j]' for i in 1:order) for j in 1:order) - # G₆ = sum(sum(es[i]*tr(Va*Fs[i]'*Vx*Fs[j])*es[j]' for i in 1:order) for j in 1:order) + G₁ = Vy[1:order, 1:order] G₂ = (my*mx'*mA')[1:order, 1:order] G₃ = transpose(G₂) @@ -37,9 +31,6 @@ G₅ = sum(sum(es[i]*ma'*Fs[j]'Ex_xx*Fs[i]*ma*es[j]' for i in 1:order) for j in 1:order)[1:order, 1:order] G₆ = sum(sum(es[i]*tr(Va*Fs[i]'*Ex_xx*Fs[j])*es[j]' for i in 1:order) for j in 1:order)[1:order, 1:order] Δ = G₁ + G₂ + G₃ + G₅ + G₆ - # G = G₁ + G₂ + G₃ + G₄ + G₅ + G₆ - - # Δ = (my - mA*mx)*(my - mA*mx)' - mA*Vyx' - Vyx*mA + S*Vx*S' + G return WishartMessage(n+2, Δ) end \ No newline at end of file diff --git a/src/rules/mv_autoregressive/x.jl b/src/rules/mv_autoregressive/x.jl index bbd54c936..18cc75bc6 100644 --- a/src/rules/mv_autoregressive/x.jl +++ b/src/rules/mv_autoregressive/x.jl @@ -16,7 +16,6 @@ begin es = [uvector(dim, i) for i in 1:order] Fs = [mask_mar(order, ds, i) for i in 1:order] - # ∏ = Iterators.product(transpose.(es), mW, es, Fs, Va, transpose.(Fs)) Λ = 
sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:order) for j in 1:order) Ξ = mA*inv(inv(mW) + Vy)*mA' + Λ From 86a24e46636ea1f6eb6fbefab877334b64a82757 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Mon, 2 Jan 2023 20:55:07 +0100 Subject: [PATCH 08/48] Bug fix --- src/nodes/mv_autoregressive.jl | 9 ++++----- src/rules/mv_autoregressive/a.jl | 8 ++++---- src/rules/mv_autoregressive/lambda.jl | 14 ++++++-------- src/rules/mv_autoregressive/marginals.jl | 6 +++--- src/rules/mv_autoregressive/x.jl | 6 +++--- src/rules/mv_autoregressive/y.jl | 6 +++--- 6 files changed, 23 insertions(+), 26 deletions(-) diff --git a/src/nodes/mv_autoregressive.jl b/src/nodes/mv_autoregressive.jl index 8608cf3cd..a3a8bcdb7 100644 --- a/src/nodes/mv_autoregressive.jl +++ b/src/nodes/mv_autoregressive.jl @@ -43,21 +43,20 @@ default_meta(::Type{MAR}) = error("MvAutoregressive node requires meta flag expl ma, Va = mean_cov(q_a) - mA = mar_companion_matrix(order, ds, ma)[1:order, 1:dim] - + mA = mar_companion_matrix(order, ds, ma)[1:ds, 1:dim] mx, Vx = ar_slice(F, myx, (dim+1):2dim), ar_slice(F, Vyx, (dim+1):2dim, (dim+1):2dim) my1, Vy1 = myx[1:ds], Vyx[1:ds, 1:ds] Vy1x = ar_slice(F, Vyx, 1:ds, dim+1:2dim) # this should be inside MARMeta - es = [uvector(order, i) for i in 1:order] - Fs = [mask_mar(order, ds, i) for i in 1:order] + es = [uvector(ds, i) for i in 1:ds] + Fs = [mask_mar(order, ds, i) for i in 1:ds] # # Euivalento to AE = (-mean(log, q_γ) + log2π + mγ*(Vy1+my1^2 - 2*mθ'*(Vy1x + mx*my1) + tr(Vθ*Vx) + mx'*Vθ*mx + mθ'*(Vx + mx*mx')*mθ)) / 2 g₁ = my1'*mΛ*my1 + tr(Vy1*mΛ) g₂ = -mx'*mA'*mΛ*my1 + tr(Vy1x*mA'*mΛ) g₃ = -g₂ - G = sum(sum(es[i]'*mΛ*es[j]*Fs[i]*(ma*ma' + Va)*Fs[j]' for i in 1:order) for j in 1:order) + G = sum(sum(es[i]'*mΛ*es[j]*Fs[i]*(ma*ma' + Va)*Fs[j]' for i in 1:ds) for j in 1:ds) g₄ = mx'*G*mx + tr(Vx*G) AE = n/2*log2π - 0.5*mean(logdet, q_Λ) + 0.5*(g₁ + g₂ + g₃ + g₄) diff --git a/src/rules/mv_autoregressive/a.jl b/src/rules/mv_autoregressive/a.jl index b14d35468..dee6b9037 100644 --- a/src/rules/mv_autoregressive/a.jl +++ b/src/rules/mv_autoregressive/a.jl @@ -14,13 +14,13 @@ mW = mar_transition(order, mΛ) # this should be inside MARMeta - es = [uvector(dim, i) for i in 1:order] - Fs = [mask_mar(order, ds, i) for i in 1:order] + es = [uvector(dim, i) for i in 1:ds] + Fs = [mask_mar(order, ds, i) for i in 1:ds] # @show Iterators.product(transpose.(es), mW, es) # @show sum(prod, Iterators.product(transpose.(es), mW, es)) # ∏ = Iterators.product(transpose.(es), mW, es, transpose.(Fs), (Vx + mx*mx'), Fs) - D = sum(sum(es[i]'*mW*es[j]*Fs[i]'*(mx*mx' + Vx)*Fs[j] for i in 1:order) for j in 1:order) - z = sum(Fs[i]'*(mx*my'+Vyx')*mW*es[i] for i in 1:order) + D = sum(sum(es[i]'*mW*es[j]*Fs[i]'*(mx*mx' + Vx)*Fs[j] for i in 1:ds) for j in 1:ds) + z = sum(Fs[i]'*(mx*my'+Vyx')*mW*es[i] for i in 1:ds) return MvNormalMeanCovariance(inv(D)*z, inv(D)) end \ No newline at end of file diff --git a/src/rules/mv_autoregressive/lambda.jl b/src/rules/mv_autoregressive/lambda.jl index 4cbb5017a..50563800f 100644 --- a/src/rules/mv_autoregressive/lambda.jl +++ b/src/rules/mv_autoregressive/lambda.jl @@ -19,17 +19,15 @@ mx, Vx = ar_slice(F, myx, (dim+1):2dim), ar_slice(F, Vyx, (dim+1):2dim, (dim+1):2dim) Vyx = ar_slice(F, Vyx, (dim+1):2dim, 1:dim) - es = [uvector(dim, i) for i in 1:order] - Fs = [mask_mar(order, ds, i) for i in 1:order] - + es = [uvector(dim, i) for i in 1:ds] + Fs = [mask_mar(order, ds, i) for i in 1:ds] S = mar_shift(order, ds) - - G₁ = Vy[1:order, 1:order] - G₂ = (my*mx'*mA')[1:order, 
1:order] + G₁ = Vy[1:ds, 1:ds] + G₂ = (my*mx'*mA')[1:ds, 1:ds] G₃ = transpose(G₂) Ex_xx = mx*mx' + Vx - G₅ = sum(sum(es[i]*ma'*Fs[j]'Ex_xx*Fs[i]*ma*es[j]' for i in 1:order) for j in 1:order)[1:order, 1:order] - G₆ = sum(sum(es[i]*tr(Va*Fs[i]'*Ex_xx*Fs[j])*es[j]' for i in 1:order) for j in 1:order)[1:order, 1:order] + G₅ = sum(sum(es[i]*ma'*Fs[j]'Ex_xx*Fs[i]*ma*es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] + G₆ = sum(sum(es[i]*tr(Va*Fs[i]'*Ex_xx*Fs[j])*es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] Δ = G₁ + G₂ + G₃ + G₅ + G₆ return WishartMessage(n+2, Δ) diff --git a/src/rules/mv_autoregressive/marginals.jl b/src/rules/mv_autoregressive/marginals.jl index c5cee4845..99778ce49 100644 --- a/src/rules/mv_autoregressive/marginals.jl +++ b/src/rules/mv_autoregressive/marginals.jl @@ -34,10 +34,10 @@ function ar_y_x_marginal( inv_f_Vx = cholinv(f_Vx) # this should be inside MARMeta - es = [uvector(dim, i) for i in 1:order] - Fs = [mask_mar(order, ds, i) for i in 1:order] + es = [uvector(dim, i) for i in 1:ds] + Fs = [mask_mar(order, ds, i) for i in 1:ds] - Ξ = inv_f_Vx + sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:order) for j in 1:order) + Ξ = inv_f_Vx + sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:ds) for j in 1:ds) W_11 = inv_b_Vy + mW diff --git a/src/rules/mv_autoregressive/x.jl b/src/rules/mv_autoregressive/x.jl index 18cc75bc6..6c626b46e 100644 --- a/src/rules/mv_autoregressive/x.jl +++ b/src/rules/mv_autoregressive/x.jl @@ -13,10 +13,10 @@ begin mW = mar_transition(getorder(meta), mΛ) # this should be inside MARMeta - es = [uvector(dim, i) for i in 1:order] - Fs = [mask_mar(order, ds, i) for i in 1:order] + es = [uvector(dim, i) for i in 1:ds] + Fs = [mask_mar(order, ds, i) for i in 1:ds] - Λ = sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:order) for j in 1:order) + Λ = sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:ds) for j in 1:ds) Ξ = mA*inv(inv(mW) + Vy)*mA' + Λ z = mA'*inv(Vy + inv(mW))*my diff --git a/src/rules/mv_autoregressive/y.jl b/src/rules/mv_autoregressive/y.jl index b46dd141a..a67d7ba1c 100644 --- a/src/rules/mv_autoregressive/y.jl +++ b/src/rules/mv_autoregressive/y.jl @@ -11,10 +11,10 @@ begin mW = mar_transition(getorder(meta), mΛ) dim = order*ds # this should be inside MARMeta - es = [uvector(dim, i) for i in 1:order] - Fs = [mask_mar(order, ds, i) for i in 1:order] + es = [uvector(dim, i) for i in 1:ds] + Fs = [mask_mar(order, ds, i) for i in 1:ds] - Λ = sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:order) for j in 1:order) + Λ = sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:ds) for j in 1:ds) Ξ = Λ + Wx From dcb5bd5c0305198241be2f3c499cdab8b9fc7292 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Tue, 3 Jan 2023 14:21:08 +0100 Subject: [PATCH 09/48] Update rules --- src/nodes/mv_autoregressive.jl | 23 ----------------------- src/rules/mv_autoregressive/a.jl | 8 +++----- src/rules/mv_autoregressive/x.jl | 7 ++----- 3 files changed, 5 insertions(+), 33 deletions(-) diff --git a/src/nodes/mv_autoregressive.jl b/src/nodes/mv_autoregressive.jl index a3a8bcdb7..c992b710a 100644 --- a/src/nodes/mv_autoregressive.jl +++ b/src/nodes/mv_autoregressive.jl @@ -74,29 +74,6 @@ default_meta(::Type{MAR}) = error("MvAutoregressive node requires meta flag expl end # Helpers for AR rules - - -# p, d, i -# function mask_mar(order, ds, index) -# Frows = order * ds -# Fcols = ds * Frows -# F = zeros(Frows, Fcols) -# FB = BlockArray(F, ) -# for k in 1:ds*order -# for j in 1:ds*order^2 -# if j == ds*(index+(k-1)*ds) -# F[k, j] = 1.0 -# else -# F[k, 
j] = .0 -# end -# end -# end -# # F[1:order, ds*index-1:ds*index] = diageye(ds) -# # F[ds+1:end, ds*index+ds+1:ds*index+ds+ds] = diageye(ds) -# @show F -# return F -# end - function mask_mar(p, d, index) F = zeros(d*p, d*d*p) rows = repeat([d], p) diff --git a/src/rules/mv_autoregressive/a.jl b/src/rules/mv_autoregressive/a.jl index dee6b9037..974fe2c87 100644 --- a/src/rules/mv_autoregressive/a.jl +++ b/src/rules/mv_autoregressive/a.jl @@ -16,11 +16,9 @@ # this should be inside MARMeta es = [uvector(dim, i) for i in 1:ds] Fs = [mask_mar(order, ds, i) for i in 1:ds] - # @show Iterators.product(transpose.(es), mW, es) - # @show sum(prod, Iterators.product(transpose.(es), mW, es)) - # ∏ = Iterators.product(transpose.(es), mW, es, transpose.(Fs), (Vx + mx*mx'), Fs) - D = sum(sum(es[i]'*mW*es[j]*Fs[i]'*(mx*mx' + Vx)*Fs[j] for i in 1:ds) for j in 1:ds) + D = sum(sum(es[j]'*mW*es[i]*Fs[i]'*(mx*mx' + Vx)*Fs[j] for i in 1:ds) for j in 1:ds) z = sum(Fs[i]'*(mx*my'+Vyx')*mW*es[i] for i in 1:ds) - return MvNormalMeanCovariance(inv(D)*z, inv(D)) + + return MvNormalWeightedMeanPrecision(z, D) end \ No newline at end of file diff --git a/src/rules/mv_autoregressive/x.jl b/src/rules/mv_autoregressive/x.jl index 6c626b46e..a797d6907 100644 --- a/src/rules/mv_autoregressive/x.jl +++ b/src/rules/mv_autoregressive/x.jl @@ -18,11 +18,8 @@ begin Λ = sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:ds) for j in 1:ds) - Ξ = mA*inv(inv(mW) + Vy)*mA' + Λ + Ξ = mA'*inv(Vy + inv(mW))*mA + Λ z = mA'*inv(Vy + inv(mW))*my - mx = inv(Ξ)*z - Vx = inv(Ξ) - - return MvNormalMeanCovariance(mx, Vx) + return MvNormalWeightedMeanPrecision(z, Ξ) end \ No newline at end of file From adc58ad804b66e8ce494d9c8b26811ec31c41376 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Tue, 3 Jan 2023 16:59:37 +0100 Subject: [PATCH 10/48] Update rules --- src/nodes/autoregressive.jl | 4 ++-- src/nodes/mv_autoregressive.jl | 20 +++++++++----------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/nodes/autoregressive.jl b/src/nodes/autoregressive.jl index eb46bc3af..2e804e4d9 100644 --- a/src/nodes/autoregressive.jl +++ b/src/nodes/autoregressive.jl @@ -55,7 +55,7 @@ default_meta(::Type{AR}) = error("Autoregressive node requires meta flag explici # correction if is_multivariate(meta) - AE += entropy(q_y_x) + # AE += entropy(q_y_x) idc = LazyArrays.Vcat(1, (order + 1):(2order)) myx_n = view(myx, idc) Vyx_n = view(Vyx, idc, idc) @@ -80,7 +80,7 @@ end # correction if is_multivariate(meta) - AE += entropy(q_y) + # AE += entropy(q_y) q_y = NormalMeanVariance(my1, Vy1) AE -= entropy(q_y) end diff --git a/src/nodes/mv_autoregressive.jl b/src/nodes/mv_autoregressive.jl index c992b710a..0410698ef 100644 --- a/src/nodes/mv_autoregressive.jl +++ b/src/nodes/mv_autoregressive.jl @@ -52,7 +52,6 @@ default_meta(::Type{MAR}) = error("MvAutoregressive node requires meta flag expl es = [uvector(ds, i) for i in 1:ds] Fs = [mask_mar(order, ds, i) for i in 1:ds] - # # Euivalento to AE = (-mean(log, q_γ) + log2π + mγ*(Vy1+my1^2 - 2*mθ'*(Vy1x + mx*my1) + tr(Vθ*Vx) + mx'*Vθ*mx + mθ'*(Vx + mx*mx')*mθ)) / 2 g₁ = my1'*mΛ*my1 + tr(Vy1*mΛ) g₂ = -mx'*mA'*mΛ*my1 + tr(Vy1x*mA'*mΛ) g₃ = -g₂ @@ -61,8 +60,7 @@ default_meta(::Type{MAR}) = error("MvAutoregressive node requires meta flag expl AE = n/2*log2π - 0.5*mean(logdet, q_Λ) + 0.5*(g₁ + g₂ + g₃ + g₄) if order > 1 - mean(q_y_x) - AE += entropy(q_y_x) + # AE += entropy(q_y_x) idc = LazyArrays.Vcat(1:order, (dim+1):2dim) myx_n = view(myx, idc) Vyx_n = view(Vyx, idc, idc) @@ -74,15 +72,15 @@ 
default_meta(::Type{MAR}) = error("MvAutoregressive node requires meta flag expl end # Helpers for AR rules -function mask_mar(p, d, index) - F = zeros(d*p, d*d*p) - rows = repeat([d], p) - cols = repeat([d], d*p) +function mask_mar(order, dimension, index) + F = zeros(dimension*order, dimension*dimension*order) + rows = repeat([dimension], order) + cols = repeat([dimension], dimension*order) FB = BlockArrays.BlockArray(F, rows, cols) - for k in 1:p - for j in 1:d*p - if j == index + (k-1)*d - view(FB, BlockArrays.Block(k, j)) .= diageye(d) + for k in 1:order + for j in 1:dimension*order + if j == index + (k-1)*dimension + view(FB, BlockArrays.Block(k, j)) .= diageye(dimension) end end end From ee8b471762959e3a07aff685e4a2852deb42c118 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Wed, 4 Jan 2023 11:43:24 +0100 Subject: [PATCH 11/48] project: add BlockArrays dependency --- Project.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1e57aec78..0657f4a22 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Dmitry Bagaev ", "Albert Podusenko Date: Wed, 4 Jan 2023 11:54:49 +0100 Subject: [PATCH 12/48] fix constructor --- src/nodes/mv_autoregressive.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/nodes/mv_autoregressive.jl b/src/nodes/mv_autoregressive.jl index 0410698ef..e84aa78ab 100644 --- a/src/nodes/mv_autoregressive.jl +++ b/src/nodes/mv_autoregressive.jl @@ -10,15 +10,16 @@ const MvAutoregressive = MAR struct MARMeta order :: Int # order (lag) of MAR ds :: Int # dimensionality of MAR process, i.e., the number of correlated AR processes -end -function MARMeta(order, ds=2) - if ds < 2 - @error "ds parameter should be > 1. Use AR node if ds = 1" + function MARMeta(order, ds=2) + if ds < 2 + @error "ds parameter should be > 1. 
Use AR node if ds = 1" + end + return new(order, ds) + end end + getorder(meta::MARMeta) = meta.order getdimensionality(meta::MARMeta) = meta.ds @node MAR Stochastic [y, x, a, Λ]
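The constructor above routes construction through `new` and validates `ds`, which is the contract every MAR rule later in the series relies on. A minimal standalone sketch of that contract — `MARMetaSketch` is a hypothetical name chosen so the sketch does not clash with the node's own type, and the throwing `error(...)` is an assumption (the patch itself only logs via `@error` and still constructs the object):

# Standalone sketch of the MARMeta contract, assuming a throwing validator.
struct MARMetaSketch
    order :: Int   # lag of the MAR process
    ds    :: Int   # number of jointly modelled AR processes, must be > 1

    function MARMetaSketch(order, ds = 2)
        ds < 2 && error("ds parameter should be > 1. Use AR node if ds = 1")
        return new(order, ds)
    end
end

meta = MARMetaSketch(2, 3)      # an order-2 MAR model over a 3-dimensional signal
dim  = meta.order * meta.ds     # the rules below stack the state into order*ds = 6 entries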
From 09140d13ea71ac5ac78656450894068f84b29a9c Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Wed, 4 Jan 2023 14:01:58 +0100 Subject: [PATCH 13/48] Update rules --- src/rules/mv_autoregressive/a.jl | 23 ++++++++++++++++++ src/rules/mv_autoregressive/lambda.jl | 34 +++++++++++++++++++++++++-- src/rules/mv_autoregressive/x.jl | 25 ++++++++++++++++++++ src/rules/mv_autoregressive/y.jl | 10 ++++++++ 4 files changed, 90 insertions(+), 2 deletions(-) diff --git a/src/rules/mv_autoregressive/a.jl b/src/rules/mv_autoregressive/a.jl index 974fe2c87..06d269c6a 100644 --- a/src/rules/mv_autoregressive/a.jl +++ b/src/rules/mv_autoregressive/a.jl @@ -20,5 +20,28 @@ D = sum(sum(es[j]'*mW*es[i]*Fs[i]'*(mx*mx' + Vx)*Fs[j] for i in 1:ds) for j in 1:ds) z = sum(Fs[i]'*(mx*my'+Vyx')*mW*es[i] for i in 1:ds) + return MvNormalWeightedMeanPrecision(z, D) +end + + +@rule MAR(:a, Marginalisation) (q_y::MultivariateNormalDistributionsFamily, q_x::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin + + order, ds = getorder(meta), getdimensionality(meta) + F = Multivariate + + dim = order*ds + + my, Vy = mean_cov(q_y) + mx, Vx = mean_cov(q_x) + mΛ = mean(q_Λ) + mW = mar_transition(order, mΛ) + + # this should be inside MARMeta + es = [uvector(dim, i) for i in 1:ds] + Fs = [mask_mar(order, ds, i) for i in 1:ds] + + D = sum(sum(es[j]'*mW*es[i]*Fs[i]'*(mx*mx' + Vx)*Fs[j] for i in 1:ds) for j in 1:ds) + z = sum(Fs[i]'*mx*my'*mW*es[i] for i in 1:ds) + return MvNormalWeightedMeanPrecision(z, D) end \ No newline at end of file diff --git a/src/rules/mv_autoregressive/lambda.jl b/src/rules/mv_autoregressive/lambda.jl index 50563800f..87d067902 100644 --- a/src/rules/mv_autoregressive/lambda.jl +++ b/src/rules/mv_autoregressive/lambda.jl @@ -22,8 +22,8 @@ es = [uvector(dim, i) for i in 1:ds] Fs = [mask_mar(order, ds, i) for i in 1:ds] S = mar_shift(order, ds) - G₁ = Vy[1:ds, 1:ds] - G₂ = (my*mx'*mA')[1:ds, 1:ds] + G₁ = (my*my' + Vy)[1:ds, 1:ds] + G₂ = ((my*mx' + Vyx)*mA')[1:ds, 1:ds] G₃ = transpose(G₂) Ex_xx = mx*mx' + Vx G₅ = sum(sum(es[i]*ma'*Fs[j]'Ex_xx*Fs[i]*ma*es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] @@ -31,4 +31,34 @@ Δ = G₁ + G₂ + G₃ + G₅ + G₆ return WishartMessage(n+2, Δ) +end + +@rule MAR(:Λ, Marginalisation) ( + q_y::MultivariateNormalDistributionsFamily, + q_x::MultivariateNormalDistributionsFamily, + q_a::MultivariateNormalDistributionsFamily, + meta::MARMeta +) = begin + order, ds = getorder(meta), getdimensionality(meta) + F = Multivariate + dim = order*ds + + my, Vy = mean_cov(q_y) + mx, Vx = mean_cov(q_x) + ma, Va = mean_cov(q_a) + + mA = mar_companion_matrix(order, ds, ma) + + es = [uvector(dim, i) for i in 1:ds] + Fs = [mask_mar(order, ds, i) for i in 1:ds] + S = mar_shift(order, ds) + G₁ = (my*my' + Vy)[1:ds, 1:ds] + G₂ = (my*mx'*mA')[1:ds, 1:ds] + G₃ = transpose(G₂) + Ex_xx = mx*mx' + Vx + G₅ = sum(sum(es[i]*ma'*Fs[j]'Ex_xx*Fs[i]*ma*es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] + G₆ = sum(sum(es[i]*tr(Va*Fs[i]'*Ex_xx*Fs[j])*es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] + Δ = G₁ + G₂ + G₃ + G₅ + G₆ + + return WishartMessage(length(my)+2, Δ) end \ No newline at end of file diff --git a/src/rules/mv_autoregressive/x.jl b/src/rules/mv_autoregressive/x.jl index a797d6907..3e1965d09 100644 --- a/src/rules/mv_autoregressive/x.jl +++ b/src/rules/mv_autoregressive/x.jl @@ -21,5 +21,30 @@ begin Ξ = mA'*inv(Vy + inv(mW))*mA + Λ z = mA'*inv(Vy + inv(mW))*my + return MvNormalWeightedMeanPrecision(z, Ξ) +end + +@rule MAR(:x, Marginalisation) (q_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin + + ma, Va = mean_cov(q_a) + my, Vy = mean_cov(q_y) + + mΛ = mean(q_Λ) + + order, ds = getorder(meta), getdimensionality(meta) + dim = order*ds + + mA = mar_companion_matrix(order, ds, ma) + mW = mar_transition(getorder(meta), mΛ) + + # this should be inside MARMeta + es = [uvector(dim, i) for i in 1:ds] + Fs = [mask_mar(order, ds, i) for i in 1:ds] + + Λ = sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:ds) for j in 1:ds) + + Ξ = mA'*mW*mA + Λ + z = mA'*mW*my + + return MvNormalWeightedMeanPrecision(z, Ξ) end \ No newline at end of file diff --git a/src/rules/mv_autoregressive/y.jl b/src/rules/mv_autoregressive/y.jl index b46dd141a..b701321ab 100644 --- a/src/rules/mv_autoregressive/y.jl +++ b/src/rules/mv_autoregressive/y.jl @@ -24,4 +24,14 @@ begin my = mA*inv(Ξ)*z return MvNormalMeanCovariance(my, Vy) +end + +@rule MAR(:y, Marginalisation) (q_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin + + order, ds = getorder(meta), getdimensionality(meta) + + mA = mar_companion_matrix(order, ds, mean(q_a)) + mW = mar_transition(getorder(meta), mean(q_Λ)) + + return MvNormalMeanPrecision(mA * mean(q_x), mW) end \ No newline at end of file From 78c5216b54e712d66d051c065e9e34854d4e8bbe Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Wed, 4 Jan 2023 14:43:57 +0100 Subject: [PATCH 14/48] Update FE --- src/nodes/autoregressive.jl | 4 ++-- src/nodes/mv_autoregressive.jl | 42 ++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/src/nodes/autoregressive.jl b/src/nodes/autoregressive.jl index 2e804e4d9..eb46bc3af 100644 --- a/src/nodes/autoregressive.jl +++ b/src/nodes/autoregressive.jl @@ -55,7 +55,7 @@ default_meta(::Type{AR}) = error("Autoregressive node requires meta flag explici # correction if is_multivariate(meta) - # AE += entropy(q_y_x) + AE += entropy(q_y_x) idc = LazyArrays.Vcat(1, (order + 1):(2order)) myx_n = view(myx, idc) Vyx_n = view(Vyx, idc, idc) @@ -80,7 +80,7 @@ end # correction if is_multivariate(meta) - # AE += entropy(q_y) + AE += entropy(q_y) q_y = NormalMeanVariance(my1, Vy1) AE -= entropy(q_y) end diff --git a/src/nodes/mv_autoregressive.jl b/src/nodes/mv_autoregressive.jl index e84aa78ab..279e40864 100644 --- a/src/nodes/mv_autoregressive.jl +++ b/src/nodes/mv_autoregressive.jl @@ -72,6 +72,48 @@ default_meta(::Type{MAR}) = error("MvAutoregressive node requires meta flag expl return AE end +@average_energy MAR ( + q_y::MultivariateNormalDistributionsFamily, + q_x::MultivariateNormalDistributionsFamily, + q_a::MultivariateNormalDistributionsFamily, + q_Λ::Wishart, + meta::MARMeta +) = begin + ma, Va = mean_cov(q_a) + my, Vy = mean_cov(q_y) + mx, Vx = mean_cov(q_x) + mΛ = mean(q_Λ) + + order, ds = getorder(meta), getdimensionality(meta) + F = Multivariate + dim = order*ds + n = dim + + ma, Va = mean_cov(q_a) + mA = mar_companion_matrix(order, ds, ma)[1:ds, 1:dim] + + my1, Vy1 = my[1:ds], Vy[1:ds, 1:ds] + + # this should be inside MARMeta + es = [uvector(ds, i) for i in 1:ds] + Fs = [mask_mar(order, ds, i) for i in 1:ds] + + g₁ = my1'*mΛ*my1 + tr(Vy1*mΛ) + g₂ = -mx'*mA'*mΛ*my1 + g₃ = -g₂ + G = sum(sum(es[i]'*mΛ*es[j]*Fs[i]*(ma*ma' + Va)*Fs[j]' for i in 1:ds) for j in 1:ds) + g₄ = mx'*G*mx + tr(Vx*G) + AE = n/2*log2π - 0.5*mean(logdet, q_Λ) + 0.5*(g₁ + g₂ + g₃ + g₄) + + if order > 1 + # AE += entropy(q_y) + q_y = MvNormalMeanCovariance(my1, Vy1) + AE -= entropy(q_y) + end + + return AE +end + # Helpers for AR rules function mask_mar(order, dimension, index) F = zeros(dimension*order, dimension*dimension*order) rows = repeat([dimension], order) cols = repeat([dimension], dimension*order)
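The helpers referenced throughout these rules (mask_mar, mar_shift, uvector, mar_companion_matrix, mar_transition) determine every matrix shape involved. The following self-contained sketch mirrors their definitions from the node file as of this point in the series, with two stated assumptions: diageye(d) is written out as Matrix{Float64}(I, d, d) in place of ReactiveMP's identity helper, and the BlockArrays-based mask is rewritten with plain index ranges (same result, no dependency):

using LinearAlgebra

# diageye(d) stands in for ReactiveMP's identity-matrix helper (assumption).
diageye(d) = Matrix{Float64}(I, d, d)

# Standard basis vector e_pos of length dim, as in the node file.
function uvector(dim, pos = 1)
    u = zeros(dim)
    u[pos] = 1.0
    return dim == 1 ? u[pos] : u
end

# Shift matrix: copies each block row one block down and zeroes the top block row,
# which is later filled in by the AR coefficients.
function mar_shift(order, ds)
    dim = order * ds
    S = diageye(dim)
    for i in dim:-1:(ds + 1)
        S[i, :] = S[i - ds, :]
    end
    S[1:ds, :] .= 0.0
    return S
end

# Mask Fᵢ selecting the i-th row-slice of coefficients from the flattened vector a;
# plain index ranges replace the BlockArrays construction used in the patch.
function mask_mar(order, ds, index)
    F = zeros(ds * order, ds * ds * order)
    for k in 1:order, j in 1:(ds * order)
        if j == index + (k - 1) * ds
            F[(k - 1) * ds .+ (1:ds), (j - 1) * ds .+ (1:ds)] = diageye(ds)
        end
    end
    return F
end

# Companion matrix: shift part plus one masked coefficient row per process.
function mar_companion_matrix(order, ds, a)
    dim = order * ds
    S  = mar_shift(order, ds)
    es = [uvector(dim, i) for i in 1:ds]
    Fs = [mask_mar(order, ds, i) for i in 1:ds]
    return S .+ sum(es[i] * a' * Fs[i]' for i in 1:ds)
end

# Pseudo-precision of the state transition: Λ governs the newest ds entries, a very
# large precision pins the deterministic shift rows (1e12 as in the version above).
function mar_transition(order, Λ)
    dim = size(Λ, 1)
    W = 1e12 * diageye(dim * order)
    W[1:dim, 1:dim] = Λ
    return W
end

# For order = 2, ds = 2 the companion matrix has the block form [A₁ A₂; I 0]:
L = mar_companion_matrix(2, 2, collect(1.0:8.0))   # rows 1–2: [1 2 5 6; 3 4 7 8]
W = mar_transition(2, [2.0 0.5; 0.5 1.0])          # 4×4, Λ in the top-left block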
Fs = [mask_mar(order, ds, i) for i in 1:ds] + + g₁ = my1'*mΛ*my1 + tr(Vy1*mΛ) + g₂ = -mx'*mA'*mΛ*my1 + g₃ = -g₂ + G = sum(sum(es[i]'*mΛ*es[j]*Fs[i]*(ma*ma' + Va)*Fs[j]' for i in 1:ds) for j in 1:ds) + g₄ = mx'*G*mx + tr(Vx*G) + AE = n/2*log2π - 0.5*mean(logdet, q_Λ) + 0.5*(g₁ + g₂ + g₃ + g₄) + + if order > 1 + # AE += entropy(q_y) + q_y = MvNormalMeanCovariance(my1, Vy1) + AE -= entropy(q_y) + end + + return AE +end + # Helpers for AR rules function mask_mar(order, dimension, index) F = zeros(dimension*order, dimension*dimension*order) From 67d47cc438a7d6fb3ca982bc382e1baa770f9d08 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Thu, 5 Jan 2023 16:38:09 +0100 Subject: [PATCH 15/48] Update rule --- src/rules/mv_autoregressive/a.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/rules/mv_autoregressive/a.jl b/src/rules/mv_autoregressive/a.jl index 06d269c6a..052225252 100644 --- a/src/rules/mv_autoregressive/a.jl +++ b/src/rules/mv_autoregressive/a.jl @@ -16,9 +16,10 @@ # this should be inside MARMeta es = [uvector(dim, i) for i in 1:ds] Fs = [mask_mar(order, ds, i) for i in 1:ds] - + S = mar_shift(order, ds) + D = sum(sum(es[j]'*mW*es[i]*Fs[i]'*(mx*mx' + Vx)*Fs[j] for i in 1:ds) for j in 1:ds) - z = sum(Fs[i]'*(mx*my'+Vyx')*mW*es[i] for i in 1:ds) + z = sum(Fs[i]'*((mx*mx'+Vx')*S' + mx*my'+Vyx')*mW*es[i] for i in 1:ds) return MvNormalWeightedMeanPrecision(z, D) end @@ -41,7 +42,7 @@ end Fs = [mask_mar(order, ds, i) for i in 1:ds] D = sum(sum(es[j]'*mW*es[i]*Fs[i]'*(mx*mx' + Vx)*Fs[j] for i in 1:ds) for j in 1:ds) - z = sum(Fs[i]'*mx*my'*mW*es[i] for i in 1:ds) + z = sum(Fs[i]'*((mx*mx'+Vx')*S' + mx*my)*mW*es[i] for i in 1:ds) return MvNormalWeightedMeanPrecision(z, D) end \ No newline at end of file From f2fd4d5dc7c3aa8ad049e5ff5b8317ba0e4e232b Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 6 Jan 2023 14:59:14 +0100 Subject: [PATCH 16/48] Update MAR rules --- src/nodes/mv_autoregressive.jl | 9 +++++---- src/rules/mv_autoregressive/a.jl | 22 +++++++++++++++------- src/rules/mv_autoregressive/lambda.jl | 2 +- 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/src/nodes/mv_autoregressive.jl b/src/nodes/mv_autoregressive.jl index 279e40864..3bf69a7ae 100644 --- a/src/nodes/mv_autoregressive.jl +++ b/src/nodes/mv_autoregressive.jl @@ -132,7 +132,7 @@ end function mar_transition(order, Λ) dim = size(Λ, 1) - W = 1e12*diageye(dim*order) + W = huge*diageye(dim*order) W[1:dim, 1:dim] = Λ return W end @@ -157,7 +157,8 @@ end function mar_companion_matrix(order, ds, a) dim = order*ds S = mar_shift(order, ds) - es = [uvector(dim, i) for i in 1:order] - Fs = [mask_mar(order, ds, i) for i in 1:order] - return S .+ sum(es[i]*a'*Fs[i]' for i in 1:order) + es = [uvector(dim, i) for i in 1:ds] + Fs = [mask_mar(order, ds, i) for i in 1:ds] + L = S .+ sum(es[i]*a'*Fs[i]' for i in 1:ds) + return L end diff --git a/src/rules/mv_autoregressive/a.jl b/src/rules/mv_autoregressive/a.jl index 052225252..861afa6d3 100644 --- a/src/rules/mv_autoregressive/a.jl +++ b/src/rules/mv_autoregressive/a.jl @@ -6,10 +6,14 @@ dim = order*ds - myx, Vyx = mean_cov(q_y_x) - my, Vy = ar_slice(F, myx, 1:dim), ar_slice(F, Vyx, 1:dim, 1:dim) - mx, Vx = ar_slice(F, myx, (dim+1):2dim), ar_slice(F, Vyx, (dim+1):2dim, (dim+1):2dim) - Vyx = ar_slice(F, Vyx, (dim+1):2dim, 1:dim) + m, V = mean_cov(q_y_x) + # @show V + + my, Vy = ar_slice(F, m, 1:dim), ar_slice(F, V, 1:dim, 1:dim) + mx, Vx = ar_slice(F, m, (dim+1):2dim), ar_slice(F, V, (dim+1):2dim, (dim+1):2dim) + Vyx = ar_slice(F, V, 
(dim+1):2dim, (dim+1):2dim) + + mΛ = mean(q_Λ) mW = mar_transition(order, mΛ) @@ -17,9 +21,13 @@ es = [uvector(dim, i) for i in 1:ds] Fs = [mask_mar(order, ds, i) for i in 1:ds] S = mar_shift(order, ds) - - D = sum(sum(es[j]'*mW*es[i]*Fs[i]'*(mx*mx' + Vx)*Fs[j] for i in 1:ds) for j in 1:ds) - z = sum(Fs[i]'*((mx*mx'+Vx')*S' + mx*my'+Vyx')*mW*es[i] for i in 1:ds) + # NOTE: prove that sum(Fs[i]'*((mx*mx'+Vx')*S')*mW*es[i] for i in 1:ds) == 0.0 + + D = sum(sum(es[i]'*mW*es[j]*Fs[i]'*(mx*mx' + Vx)*Fs[j] for i in 1:ds) for j in 1:ds) + # z = sum(Fs[i]'*((mx*mx'+Vx')*S' + mx*my'+Vyx')*mW*es[i] for i in 1:ds) + z = sum(Fs[i]'*(mx*my'+Vyx')*mW*es[i] for i in 1:ds) + + # @show mx*my' return MvNormalWeightedMeanPrecision(z, D) end diff --git a/src/rules/mv_autoregressive/lambda.jl b/src/rules/mv_autoregressive/lambda.jl index 87d067902..3316ff70c 100644 --- a/src/rules/mv_autoregressive/lambda.jl +++ b/src/rules/mv_autoregressive/lambda.jl @@ -17,7 +17,7 @@ myx, Vyx = mean_cov(q_y_x) my, Vy = ar_slice(F, myx, 1:dim), ar_slice(F, Vyx, 1:dim, 1:dim) mx, Vx = ar_slice(F, myx, (dim+1):2dim), ar_slice(F, Vyx, (dim+1):2dim, (dim+1):2dim) - Vyx = ar_slice(F, Vyx, (dim+1):2dim, 1:dim) + Vyx = ar_slice(F, Vyx, (dim+1):2dim, (dim+1):2dim) es = [uvector(dim, i) for i in 1:ds] Fs = [mask_mar(order, ds, i) for i in 1:ds] From 5eab7be67e7cec84a33f04a82b3b4be1a3cd84c8 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 6 Jan 2023 18:35:23 +0100 Subject: [PATCH 17/48] Update rules --- src/rules/mv_autoregressive/a.jl | 6 ++---- src/rules/mv_autoregressive/lambda.jl | 13 +++++++------ src/rules/mv_autoregressive/y.jl | 2 +- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/rules/mv_autoregressive/a.jl b/src/rules/mv_autoregressive/a.jl index 861afa6d3..a35356c2a 100644 --- a/src/rules/mv_autoregressive/a.jl +++ b/src/rules/mv_autoregressive/a.jl @@ -7,11 +7,11 @@ dim = order*ds m, V = mean_cov(q_y_x) - # @show V my, Vy = ar_slice(F, m, 1:dim), ar_slice(F, V, 1:dim, 1:dim) mx, Vx = ar_slice(F, m, (dim+1):2dim), ar_slice(F, V, (dim+1):2dim, (dim+1):2dim) - Vyx = ar_slice(F, V, (dim+1):2dim, (dim+1):2dim) + Vyx = ar_slice(F, V, 1:dim, dim+1:2dim) + mΛ = mean(q_Λ) @@ -27,8 +27,6 @@ # z = sum(Fs[i]'*((mx*mx'+Vx')*S' + mx*my'+Vyx')*mW*es[i] for i in 1:ds) z = sum(Fs[i]'*(mx*my'+Vyx')*mW*es[i] for i in 1:ds) - # @show mx*my' - return MvNormalWeightedMeanPrecision(z, D) end diff --git a/src/rules/mv_autoregressive/lambda.jl b/src/rules/mv_autoregressive/lambda.jl index 3316ff70c..2218daf23 100644 --- a/src/rules/mv_autoregressive/lambda.jl +++ b/src/rules/mv_autoregressive/lambda.jl @@ -9,15 +9,15 @@ n = div(ndims(q_y_x), 2) - y_x_mean, y_x_cov = mean_cov(q_y_x) + # y_x_mean, y_x_cov = mean_cov(q_y_x) ma, Va = mean_cov(q_a) mA = mar_companion_matrix(order, ds, ma) - myx, Vyx = mean_cov(q_y_x) - my, Vy = ar_slice(F, myx, 1:dim), ar_slice(F, Vyx, 1:dim, 1:dim) - mx, Vx = ar_slice(F, myx, (dim+1):2dim), ar_slice(F, Vyx, (dim+1):2dim, (dim+1):2dim) - Vyx = ar_slice(F, Vyx, (dim+1):2dim, (dim+1):2dim) + m, V = mean_cov(q_y_x) + my, Vy = ar_slice(F, m, 1:dim), ar_slice(F, V, 1:dim, 1:dim) + mx, Vx = ar_slice(F, m, (dim+1):2dim), ar_slice(F, V, (dim+1):2dim, (dim+1):2dim) + Vyx = ar_slice(F, V, 1:dim, dim+1:2dim) es = [uvector(dim, i) for i in 1:ds] Fs = [mask_mar(order, ds, i) for i in 1:ds] @@ -28,9 +28,10 @@ Ex_xx = mx*mx' + Vx G₅ = sum(sum(es[i]*ma'*Fs[j]'Ex_xx*Fs[i]*ma*es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] G₆ = sum(sum(es[i]*tr(Va*Fs[i]'*Ex_xx*Fs[j])*es[j]' for i in 1:ds) for j in 
1:ds)[1:ds, 1:ds] - Δ = G₁ + G₂ + G₃ + G₅ + G₆ + @show Δ = G₁ - G₂ - G₃ + G₅ + G₆ return WishartMessage(n+2, Δ) + end @rule MAR(:Λ, Marginalisation) ( diff --git a/src/rules/mv_autoregressive/y.jl b/src/rules/mv_autoregressive/y.jl index b701321ab..2a3ffa678 100644 --- a/src/rules/mv_autoregressive/y.jl +++ b/src/rules/mv_autoregressive/y.jl @@ -20,7 +20,7 @@ begin Ξ = Λ + Wx z = Wx*mx - Vy = mA*inv(Ξ)*mA' + inv(mW) + Vy = mA*inv(Ξ)*mA' + inv(mW) my = mA*inv(Ξ)*z return MvNormalMeanCovariance(my, Vy) From 395ea370a047c85d8663b4e3abf15f274b753b08 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Sat, 7 Jan 2023 00:05:37 +0100 Subject: [PATCH 18/48] WIP: Update marginals & lambda --- src/rules/mv_autoregressive/lambda.jl | 6 +++--- src/rules/mv_autoregressive/marginals.jl | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/rules/mv_autoregressive/lambda.jl b/src/rules/mv_autoregressive/lambda.jl index 2218daf23..171ebd03f 100644 --- a/src/rules/mv_autoregressive/lambda.jl +++ b/src/rules/mv_autoregressive/lambda.jl @@ -26,9 +26,9 @@ G₂ = ((my*mx' + Vyx)*mA')[1:ds, 1:ds] G₃ = transpose(G₂) Ex_xx = mx*mx' + Vx - G₅ = sum(sum(es[i]*ma'*Fs[j]'Ex_xx*Fs[i]*ma*es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] - G₆ = sum(sum(es[i]*tr(Va*Fs[i]'*Ex_xx*Fs[j])*es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] - @show Δ = G₁ - G₂ - G₃ + G₅ + G₆ + G₅ = sum(sum(es[i]*ma'*Fs[i]'Ex_xx*Fs[j]*ma*es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] + G₆ = sum(sum(es[i]*tr(Fs[i]'*Ex_xx*Fs[j]*Va)*es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] + Δ = G₁ - G₂ - G₃ + G₅ + G₆ return WishartMessage(n+2, Δ) diff --git a/src/rules/mv_autoregressive/marginals.jl b/src/rules/mv_autoregressive/marginals.jl index 99778ce49..4280bc800 100644 --- a/src/rules/mv_autoregressive/marginals.jl +++ b/src/rules/mv_autoregressive/marginals.jl @@ -44,8 +44,7 @@ function ar_y_x_marginal( # negate_inplace!(mW * mA) W_12 = -(mW * mA) - # Equivalent to - W_21 = (-mA' * mW) + W_21 = (-mA' * mW') W_22 = Ξ + mA' * mW * mA From 272581348558b35b4ce68d7d095c15b09e594f34 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Sat, 7 Jan 2023 19:46:42 +0100 Subject: [PATCH 19/48] Fix bug --- src/rules/mv_autoregressive/a.jl | 2 -- src/rules/mv_autoregressive/lambda.jl | 13 +++++++------ src/rules/mv_autoregressive/marginals.jl | 2 +- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/rules/mv_autoregressive/a.jl b/src/rules/mv_autoregressive/a.jl index a35356c2a..b12b25e31 100644 --- a/src/rules/mv_autoregressive/a.jl +++ b/src/rules/mv_autoregressive/a.jl @@ -11,8 +11,6 @@ my, Vy = ar_slice(F, m, 1:dim), ar_slice(F, V, 1:dim, 1:dim) mx, Vx = ar_slice(F, m, (dim+1):2dim), ar_slice(F, V, (dim+1):2dim, (dim+1):2dim) Vyx = ar_slice(F, V, 1:dim, dim+1:2dim) - - mΛ = mean(q_Λ) mW = mar_transition(order, mΛ) diff --git a/src/rules/mv_autoregressive/lambda.jl b/src/rules/mv_autoregressive/lambda.jl index 171ebd03f..ad21ce759 100644 --- a/src/rules/mv_autoregressive/lambda.jl +++ b/src/rules/mv_autoregressive/lambda.jl @@ -6,8 +6,8 @@ order, ds = getorder(meta), getdimensionality(meta) F = Multivariate dim = order*ds - - n = div(ndims(q_y_x), 2) + + # n = div(ndims(q_y_x), 2) # y_x_mean, y_x_cov = mean_cov(q_y_x) ma, Va = mean_cov(q_a) @@ -24,13 +24,14 @@ S = mar_shift(order, ds) G₁ = (my*my' + Vy)[1:ds, 1:ds] G₂ = ((my*mx' + Vyx)*mA')[1:ds, 1:ds] - G₃ = transpose(G₂) + # G₃ = transpose(G₂) + G₃ = (mA*(mx*my' + Vyx'))[1:ds, 1:ds] Ex_xx = mx*mx' + Vx G₅ = sum(sum(es[i]*ma'*Fs[i]'Ex_xx*Fs[j]*ma*es[j]' for i in 1:ds) for j 
in 1:ds)[1:ds, 1:ds] G₆ = sum(sum(es[i]*tr(Fs[i]'*Ex_xx*Fs[j]*Va)*es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] Δ = G₁ - G₂ - G₃ + G₅ + G₆ - return WishartMessage(n+2, Δ) + return WishartMessage(ds+2, Δ) end @@ -59,7 +60,7 @@ end Ex_xx = mx*mx' + Vx G₅ = sum(sum(es[i]*ma'*Fs[j]'Ex_xx*Fs[i]*ma*es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] G₆ = sum(sum(es[i]*tr(Va*Fs[i]'*Ex_xx*Fs[j])*es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] - Δ = G₁ + G₂ + G₃ + G₅ + G₆ + Δ = G₁ - G₂ - G₃ + G₅ + G₆ - return WishartMessage(length(my)+2, Δ) + return WishartMessage(length(my)+2, inv(Δ)) end \ No newline at end of file diff --git a/src/rules/mv_autoregressive/marginals.jl b/src/rules/mv_autoregressive/marginals.jl index 4280bc800..c97c34861 100644 --- a/src/rules/mv_autoregressive/marginals.jl +++ b/src/rules/mv_autoregressive/marginals.jl @@ -44,7 +44,7 @@ function ar_y_x_marginal( # negate_inplace!(mW * mA) W_12 = -(mW * mA) - W_21 = (-mA' * mW') + W_21 = (-mA' * mW) W_22 = Ξ + mA' * mW * mA From cfc4243e2310ec2d48e4e9238be4ea113810192c Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Mon, 9 Jan 2023 14:46:23 +0100 Subject: [PATCH 20/48] Fix backward rule --- src/rules/mv_autoregressive/x.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/rules/mv_autoregressive/x.jl b/src/rules/mv_autoregressive/x.jl index 3e1965d09..929f26d22 100644 --- a/src/rules/mv_autoregressive/x.jl +++ b/src/rules/mv_autoregressive/x.jl @@ -18,10 +18,13 @@ begin Λ = sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:ds) for j in 1:ds) - Ξ = mA'*inv(Vy + inv(mW))*mA + Λ - z = mA'*inv(Vy + inv(mW))*my + Σ₁ = pinv(mA)*(Vy)*pinv(mA)' + pinv(mA'*mW*mA) + @show Σ₁ == pinv(mA)*(Vy + inv(mW))*pinv(mA)' + Σ₂ = inv(Λ) + θ = pinv(pinv(Σ₁) + inv(Σ₂)) + z = θ*pinv(Σ₁)*pinv(mA)*my - return MvNormalWeightedMeanPrecision(z, Ξ) + return MvNormalMeanCovariance(z, θ) end @rule MAR(:x, Marginalisation) (q_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin From 7a91826b99b8956a47f34482a41261abbfc854e2 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Tue, 10 Jan 2023 12:01:03 +0100 Subject: [PATCH 21/48] Update rules --- src/nodes/mv_autoregressive.jl | 2 +- src/rules/mv_autoregressive/marginals.jl | 2 +- src/rules/mv_autoregressive/x.jl | 16 +++++++++------- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/nodes/mv_autoregressive.jl b/src/nodes/mv_autoregressive.jl index 3bf69a7ae..a71c4a7a8 100644 --- a/src/nodes/mv_autoregressive.jl +++ b/src/nodes/mv_autoregressive.jl @@ -132,7 +132,7 @@ end function mar_transition(order, Λ) dim = size(Λ, 1) - W = huge*diageye(dim*order) + W = 1.0*diageye(dim*order) W[1:dim, 1:dim] = Λ return W end diff --git a/src/rules/mv_autoregressive/marginals.jl b/src/rules/mv_autoregressive/marginals.jl index c97c34861..51b7a1deb 100644 --- a/src/rules/mv_autoregressive/marginals.jl +++ b/src/rules/mv_autoregressive/marginals.jl @@ -44,7 +44,7 @@ function ar_y_x_marginal( # negate_inplace!(mW * mA) W_12 = -(mW * mA) - W_21 = (-mA' * mW) + W_21 = -(mA' * mW) W_22 = Ξ + mA' * mW * mA diff --git a/src/rules/mv_autoregressive/x.jl b/src/rules/mv_autoregressive/x.jl index 929f26d22..6625b64fd 100644 --- a/src/rules/mv_autoregressive/x.jl +++ b/src/rules/mv_autoregressive/x.jl @@ -11,20 +11,22 @@ begin mA = mar_companion_matrix(order, ds, ma) mW = mar_transition(getorder(meta), mΛ) - # this should be inside MARMeta es = [uvector(dim, i) for i in 1:ds] Fs = [mask_mar(order, ds, i) for i in 1:ds] Λ = 
sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:ds) for j in 1:ds) - Σ₁ = pinv(mA)*(Vy)*pinv(mA)' + pinv(mA'*mW*mA) - @show Σ₁ == pinv(mA)*(Vy + inv(mW))*pinv(mA)' - Σ₂ = inv(Λ) - θ = pinv(pinv(Σ₁) + inv(Σ₂)) - z = θ*pinv(Σ₁)*pinv(mA)*my + Σ₁ = Hermitian(pinv(mA)*(Vy)*pinv(mA') + pinv(mA'*mW*mA)) + # Σ₂ = inv(Λ) + # θ = Hermitian(pinv(pinv(Σ₁) + inv(Σ₂))) + # θ = Hermitian(pinv(Σ₁) + inv(Σ₂)) + θ = Hermitian(inv(Σ₁) + Λ) + # z = θ*pinv(Σ₁)*pinv(mA)*my + z = inv(Σ₁)*pinv(mA)*my - return MvNormalMeanCovariance(z, θ) + # return MvNormalMeanCovariance(z, θ) + return MvNormalWeightedMeanPrecision(z, θ) end @rule MAR(:x, Marginalisation) (q_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin From d5248e0cd6677fce67bcfcdcf4f38cfe0105eb25 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Tue, 10 Jan 2023 13:14:21 +0100 Subject: [PATCH 22/48] Fix FE --- src/nodes/mv_autoregressive.jl | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/nodes/mv_autoregressive.jl b/src/nodes/mv_autoregressive.jl index a71c4a7a8..fc49af2f7 100644 --- a/src/nodes/mv_autoregressive.jl +++ b/src/nodes/mv_autoregressive.jl @@ -45,24 +45,28 @@ default_meta(::Type{MAR}) = error("MvAutoregressive node requires meta flag expl ma, Va = mean_cov(q_a) mA = mar_companion_matrix(order, ds, ma)[1:ds, 1:dim] + mx, Vx = ar_slice(F, myx, (dim+1):2dim), ar_slice(F, Vyx, (dim+1):2dim, (dim+1):2dim) my1, Vy1 = myx[1:ds], Vyx[1:ds, 1:ds] Vy1x = ar_slice(F, Vyx, 1:ds, dim+1:2dim) + # @show Vyx + # @show Vy1x + # this should be inside MARMeta es = [uvector(ds, i) for i in 1:ds] Fs = [mask_mar(order, ds, i) for i in 1:ds] g₁ = my1'*mΛ*my1 + tr(Vy1*mΛ) - g₂ = -mx'*mA'*mΛ*my1 + tr(Vy1x*mA'*mΛ) - g₃ = -g₂ + g₂ = mx'*mA'*mΛ*my1 + tr(Vy1x*mA'*mΛ) + g₃ = g₂ G = sum(sum(es[i]'*mΛ*es[j]*Fs[i]*(ma*ma' + Va)*Fs[j]' for i in 1:ds) for j in 1:ds) g₄ = mx'*G*mx + tr(Vx*G) - AE = n/2*log2π - 0.5*mean(logdet, q_Λ) + 0.5*(g₁ + g₂ + g₃ + g₄) + AE = n/2*log2π - 0.5*mean(logdet, q_Λ) + 0.5*(g₁ - g₂ - g₃ + g₄) if order > 1 - # AE += entropy(q_y_x) - idc = LazyArrays.Vcat(1:order, (dim+1):2dim) + AE += entropy(q_y_x) + idc = LazyArrays.Vcat(1:ds, dim+1:2dim) myx_n = view(myx, idc) Vyx_n = view(Vyx, idc, idc) q_y_x = MvNormalMeanCovariance(myx_n, Vyx_n) From 01f59f11fc314926a73ab56112c30105141b5852 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Tue, 10 Jan 2023 13:18:30 +0100 Subject: [PATCH 23/48] Clean up --- src/nodes/mv_autoregressive.jl | 2 +- src/rules/mv_autoregressive/a.jl | 3 +-- src/rules/mv_autoregressive/lambda.jl | 6 +----- src/rules/mv_autoregressive/x.jl | 9 ++------- 4 files changed, 5 insertions(+), 15 deletions(-) diff --git a/src/nodes/mv_autoregressive.jl b/src/nodes/mv_autoregressive.jl index fc49af2f7..35cd47056 100644 --- a/src/nodes/mv_autoregressive.jl +++ b/src/nodes/mv_autoregressive.jl @@ -110,7 +110,7 @@ end AE = n/2*log2π - 0.5*mean(logdet, q_Λ) + 0.5*(g₁ + g₂ + g₃ + g₄) if order > 1 - # AE += entropy(q_y) + AE += entropy(q_y) q_y = MvNormalMeanCovariance(my1, Vy1) AE -= entropy(q_y) end diff --git a/src/rules/mv_autoregressive/a.jl b/src/rules/mv_autoregressive/a.jl index b12b25e31..b0ca27a69 100644 --- a/src/rules/mv_autoregressive/a.jl +++ b/src/rules/mv_autoregressive/a.jl @@ -19,10 +19,9 @@ es = [uvector(dim, i) for i in 1:ds] Fs = [mask_mar(order, ds, i) for i in 1:ds] S = mar_shift(order, ds) - # NOTE: prove that sum(Fs[i]'*((mx*mx'+Vx')*S')*mW*es[i] for i in 1:ds) == 0.0 + # NOTE: prove that 
sum(Fs[i]'*((mx*mx'+Vx')*S')*mW*es[i] for i in 1:ds) == 0.0 D = sum(sum(es[i]'*mW*es[j]*Fs[i]'*(mx*mx' + Vx)*Fs[j] for i in 1:ds) for j in 1:ds) - # z = sum(Fs[i]'*((mx*mx'+Vx')*S' + mx*my'+Vyx')*mW*es[i] for i in 1:ds) z = sum(Fs[i]'*(mx*my'+Vyx')*mW*es[i] for i in 1:ds) return MvNormalWeightedMeanPrecision(z, D) diff --git a/src/rules/mv_autoregressive/lambda.jl b/src/rules/mv_autoregressive/lambda.jl index ad21ce759..73bd0bdec 100644 --- a/src/rules/mv_autoregressive/lambda.jl +++ b/src/rules/mv_autoregressive/lambda.jl @@ -7,9 +7,6 @@ F = Multivariate dim = order*ds - # n = div(ndims(q_y_x), 2) - - # y_x_mean, y_x_cov = mean_cov(q_y_x) ma, Va = mean_cov(q_a) mA = mar_companion_matrix(order, ds, ma) @@ -24,8 +21,7 @@ S = mar_shift(order, ds) G₁ = (my*my' + Vy)[1:ds, 1:ds] G₂ = ((my*mx' + Vyx)*mA')[1:ds, 1:ds] - # G₃ = transpose(G₂) - G₃ = (mA*(mx*my' + Vyx'))[1:ds, 1:ds] + G₃ = transpose(G₂) Ex_xx = mx*mx' + Vx G₅ = sum(sum(es[i]*ma'*Fs[i]'Ex_xx*Fs[j]*ma*es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] G₆ = sum(sum(es[i]*tr(Fs[i]'*Ex_xx*Fs[j]*Va)*es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] diff --git a/src/rules/mv_autoregressive/x.jl b/src/rules/mv_autoregressive/x.jl index 6625b64fd..bdcd166e2 100644 --- a/src/rules/mv_autoregressive/x.jl +++ b/src/rules/mv_autoregressive/x.jl @@ -18,15 +18,10 @@ begin Λ = sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:ds) for j in 1:ds) Σ₁ = Hermitian(pinv(mA)*(Vy)*pinv(mA') + pinv(mA'*mW*mA)) - # Σ₂ = inv(Λ) - # θ = Hermitian(pinv(pinv(Σ₁) + inv(Σ₂))) - # θ = Hermitian(pinv(Σ₁) + inv(Σ₂)) - θ = Hermitian(inv(Σ₁) + Λ) - # z = θ*pinv(Σ₁)*pinv(mA)*my + Ξ = Hermitian(inv(Σ₁) + Λ) z = inv(Σ₁)*pinv(mA)*my - # return MvNormalMeanCovariance(z, θ) - return MvNormalWeightedMeanPrecision(z, θ) + return MvNormalWeightedMeanPrecision(z, Ξ) end @rule MAR(:x, Marginalisation) (q_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin From a839f3d4d646d67c6de6a24c4c54259e185e5ab9 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Wed, 11 Jan 2023 11:03:50 +0100 Subject: [PATCH 24/48] Update rules --- src/rules/mv_autoregressive/helpers.jl | 0 src/rules/mv_autoregressive/lambda.jl | 2 +- src/rules/mv_autoregressive/x.jl | 10 ++++++---- 3 files changed, 7 insertions(+), 5 deletions(-) delete mode 100644 src/rules/mv_autoregressive/helpers.jl diff --git a/src/rules/mv_autoregressive/helpers.jl b/src/rules/mv_autoregressive/helpers.jl deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/rules/mv_autoregressive/lambda.jl b/src/rules/mv_autoregressive/lambda.jl index 73bd0bdec..f73d30c05 100644 --- a/src/rules/mv_autoregressive/lambda.jl +++ b/src/rules/mv_autoregressive/lambda.jl @@ -58,5 +58,5 @@ end G₆ = sum(sum(es[i]*tr(Va*Fs[i]'*Ex_xx*Fs[j])*es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] Δ = G₁ - G₂ - G₃ + G₅ + G₆ - return WishartMessage(length(my)+2, inv(Δ)) + return WishartMessage(ds+2, Δ) end \ No newline at end of file diff --git a/src/rules/mv_autoregressive/x.jl b/src/rules/mv_autoregressive/x.jl index bdcd166e2..20addd144 100644 --- a/src/rules/mv_autoregressive/x.jl +++ b/src/rules/mv_autoregressive/x.jl @@ -18,8 +18,9 @@ begin Λ = sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:ds) for j in 1:ds) Σ₁ = Hermitian(pinv(mA)*(Vy)*pinv(mA') + pinv(mA'*mW*mA)) - Ξ = Hermitian(inv(Σ₁) + Λ) - z = inv(Σ₁)*pinv(mA)*my + + Ξ = (pinv(Σ₁) + Λ) + z = pinv(Σ₁)*pinv(mA)*my return MvNormalWeightedMeanPrecision(z, Ξ) end @@ -42,9 +43,10 @@ end Fs = [mask_mar(order, ds, i) for i in 
1:ds] Λ = sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:ds) for j in 1:ds) + Λ₀ = Hermitian(mA'*mW*mA) - Ξ = mA'*mW*mA + Λ - z = mA'*mW*my + Ξ = Λ₀ + Λ + z = Λ₀*pinv(mA)*my return MvNormalWeightedMeanPrecision(z, Ξ) end \ No newline at end of file From b52f77b7e021b0a530cd6d330aad6a97dd94a60d Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Wed, 11 Jan 2023 11:14:40 +0100 Subject: [PATCH 25/48] Update MF rules --- src/rules/mv_autoregressive/a.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/rules/mv_autoregressive/a.jl b/src/rules/mv_autoregressive/a.jl index b0ca27a69..c8f2459c9 100644 --- a/src/rules/mv_autoregressive/a.jl +++ b/src/rules/mv_autoregressive/a.jl @@ -37,15 +37,18 @@ end my, Vy = mean_cov(q_y) mx, Vx = mean_cov(q_x) - mΛ = mean(q_Λ) + mΛ = mean(q_Λ) + mW = mar_transition(order, mΛ) + S = mar_shift(order, ds) # this should be inside MARMeta es = [uvector(dim, i) for i in 1:ds] Fs = [mask_mar(order, ds, i) for i in 1:ds] + D = sum(sum(es[j]'*mW*es[i]*Fs[i]'*(mx*mx' + Vx)*Fs[j] for i in 1:ds) for j in 1:ds) - z = sum(Fs[i]'*((mx*mx'+Vx')*S' + mx*my)*mW*es[i] for i in 1:ds) + z = sum(Fs[i]'*((mx*mx'+Vx')*S' + mx*my')*mW*es[i] for i in 1:ds) return MvNormalWeightedMeanPrecision(z, D) end \ No newline at end of file From 87399c271b52491d9d6720bd78813c8066f489d5 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Tue, 24 Jan 2023 16:10:16 +0100 Subject: [PATCH 26/48] Modify variables structures for predictions functionality --- src/message.jl | 30 ++++++++++++++++++------------ src/variables/data.jl | 20 +++++++++++++++++++- src/variables/random.jl | 2 +- src/variables/variable.jl | 5 ++++- 4 files changed, 42 insertions(+), 15 deletions(-) diff --git a/src/message.jl b/src/message.jl index cb6a48ac7..64c85e7c6 100644 --- a/src/message.jl +++ b/src/message.jl @@ -323,18 +323,24 @@ function materialize!(mapping::MessageMapping, dependencies) # Message is initial if it is not clamped and all of the inputs are either clamped or initial is_message_initial = !is_message_clamped && (__check_all(is_clamped_or_initial, messages) && __check_all(is_clamped_or_initial, marginals)) - result, addons = rule( - message_mapping_fform(mapping), - mapping.vtag, - mapping.vconstraint, - mapping.msgs_names, - messages, - mapping.marginals_names, - marginals, - mapping.meta, - mapping.addons, - mapping.factornode - ) + result, addons = if !isnothing(messages) && any(ismissing, TupleTools.flatten(getdata.(messages))) + missing, mapping.addons + elseif !isnothing(marginals) && any(ismissing, TupleTools.flatten(getdata.(marginals))) + missing, mapping.addons + else + rule( + message_mapping_fform(mapping), + mapping.vtag, + mapping.vconstraint, + mapping.msgs_names, + messages, + mapping.marginals_names, + marginals, + mapping.meta, + mapping.addons, + mapping.factornode + ) + end # Inject extra addons after the rule has been executed addons = message_mapping_addons(mapping, getdata(messages), getdata(marginals), result, addons) diff --git a/src/variables/data.jl b/src/variables/data.jl index dac060e9e..3ab823c4f 100644 --- a/src/variables/data.jl +++ b/src/variables/data.jl @@ -5,6 +5,8 @@ import Base: show mutable struct DataVariable{D, S} <: AbstractVariable name :: Symbol collection_type :: AbstractVariableCollectionType + prediction :: MarginalObservable + input_messages :: Vector{MessageObservable{AbstractMessage}} messageout :: S nconnected :: Int end @@ -70,7 +72,7 @@ datavar(name::Symbol, ::Type{D}, dims::Tuple) where {D} datavar(name::Symbol, 
::Type{D}, dims::Vararg{Int}) where {D} = datavar(DataVariableCreationOptions(D), name, D, dims) datavar(options::DataVariableCreationOptions{S}, name::Symbol, ::Type{D}, collection_type::AbstractVariableCollectionType = VariableIndividual()) where {S, D} = - DataVariable{D, S}(name, collection_type, options.subject, 0) + DataVariable{D, S}(name, collection_type, MarginalObservable(), Vector{MessageObservable{AbstractMessage}}(), options.subject, 0) function datavar(options::DataVariableCreationOptions, name::Symbol, ::Type{D}, length::Int) where {D} return map(i -> datavar(similar(options), name, D, VariableVector(i)), 1:length) @@ -165,5 +167,21 @@ setanonymous!(::DataVariable, ::Bool) = nothing function setmessagein!(datavar::DataVariable, ::Int, messagein) datavar.nconnected += 1 + push!(datavar.input_messages, messagein) return nothing end + +marginal_prod_fn(datavar::DataVariable) = + marginal_prod_fn(FoldLeftProdStrategy(), ProdAnalytical(), UnspecifiedFormConstraint(), FormConstraintCheckLast()) + +_getprediction(datavar::DataVariable) = datavar.prediction +_setprediction!(datavar::DataVariable, observable) = connect!(_getprediction(datavar), observable) +_makeprediction(datavar::DataVariable) = collectLatest(AbstractMessage, Marginal, datavar.input_messages, marginal_prod_fn(datavar)) + +# options here must implement at least `Rocket.getscheduler` +function activate!(datavar::DataVariable, options) + + _setprediction!(datavar, _makeprediction(datavar)) + + return nothing +end \ No newline at end of file diff --git a/src/variables/random.jl b/src/variables/random.jl index 900addc00..919da8571 100644 --- a/src/variables/random.jl +++ b/src/variables/random.jl @@ -291,4 +291,4 @@ function initialize_output_messages!(chain::EqualityChain, randomvar::RandomVari randomvar.output_initialised = true return nothing -end +end \ No newline at end of file diff --git a/src/variables/variable.jl b/src/variables/variable.jl index 0dc48ea36..5d2be4800 100644 --- a/src/variables/variable.jl +++ b/src/variables/variable.jl @@ -1,7 +1,7 @@ export AbstractVariable, degree export is_clamped, is_marginalisation, is_moment_matching export FoldLeftProdStrategy, FoldRightProdStrategy, CustomProdStrategy -export getmarginal, getmarginals, setmarginal!, setmarginals!, name, as_variable +export getprediction, getpredictions, getmarginal, getmarginals, setmarginal!, setmarginals!, name, as_variable export setmessage!, setmessages! 
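# --- editor's sketch (not part of the patch) ----------------------------------
# Rough intended usage of the prediction API introduced in this commit; only
# `datavar`, `getprediction` and `activate!` come from the patch itself, the
# wiring around them is an assumption for illustration:
#
#   using ReactiveMP, Rocket
#
#   y = datavar(:y, Float64)
#   # ... `setmessagein!` collects the node messages targeting `y`, and
#   # `activate!(y, options)` connects `_makeprediction` to the observable ...
#   subscribe!(getprediction(y), (prediction) -> println("prediction: ", prediction))
#
# `getprediction` exposes the product of all messages flowing into the data
# variable (computed via `collectLatest`), which is what a predicted `missing`
# observation resolves to.
# -------------------------------------------------------------------------------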
using Rocket @@ -80,6 +80,9 @@ add_pipeline_stage!(variable::AbstractVariable, stage) = error("Its not possible # Helper functions # Getters +getprediction(variable::AbstractVariable) = _getprediction(variable) +getpredictions(variables::AbstractArray{<:AbstractVariable}) = collectLatest(map(v -> getprediction(v), variables)) + getmarginal(variable::AbstractVariable) = getmarginal(variable, SkipInitial()) getmarginal(variable::AbstractVariable, skip_strategy::MarginalSkipStrategy) = apply_skip_filter(_getmarginal(variable), skip_strategy) From ae0e770af210fffa341e1ec48c83739c1ca49793 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Tue, 24 Jan 2023 17:22:56 +0100 Subject: [PATCH 27/48] Make format --- src/message.jl | 4 ++-- src/variables/data.jl | 6 ++---- src/variables/random.jl | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/message.jl b/src/message.jl index 64c85e7c6..2d914180c 100644 --- a/src/message.jl +++ b/src/message.jl @@ -323,11 +323,11 @@ function materialize!(mapping::MessageMapping, dependencies) # Message is initial if it is not clamped and all of the inputs are either clamped or initial is_message_initial = !is_message_clamped && (__check_all(is_clamped_or_initial, messages) && __check_all(is_clamped_or_initial, marginals)) - result, addons = if !isnothing(messages) && any(ismissing, TupleTools.flatten(getdata.(messages))) + result, addons = if !isnothing(messages) && any(ismissing, TupleTools.flatten(getdata.(messages))) missing, mapping.addons elseif !isnothing(marginals) && any(ismissing, TupleTools.flatten(getdata.(marginals))) missing, mapping.addons - else + else rule( message_mapping_fform(mapping), mapping.vtag, diff --git a/src/variables/data.jl b/src/variables/data.jl index 3ab823c4f..dadf85672 100644 --- a/src/variables/data.jl +++ b/src/variables/data.jl @@ -171,8 +171,7 @@ function setmessagein!(datavar::DataVariable, ::Int, messagein) return nothing end -marginal_prod_fn(datavar::DataVariable) = - marginal_prod_fn(FoldLeftProdStrategy(), ProdAnalytical(), UnspecifiedFormConstraint(), FormConstraintCheckLast()) +marginal_prod_fn(datavar::DataVariable) = marginal_prod_fn(FoldLeftProdStrategy(), ProdAnalytical(), UnspecifiedFormConstraint(), FormConstraintCheckLast()) _getprediction(datavar::DataVariable) = datavar.prediction _setprediction!(datavar::DataVariable, observable) = connect!(_getprediction(datavar), observable) @@ -180,8 +179,7 @@ _makeprediction(datavar::DataVariable) = collectLatest(AbstractMessa # options here must implement at least `Rocket.getscheduler` function activate!(datavar::DataVariable, options) - _setprediction!(datavar, _makeprediction(datavar)) return nothing -end \ No newline at end of file +end diff --git a/src/variables/random.jl b/src/variables/random.jl index 919da8571..900addc00 100644 --- a/src/variables/random.jl +++ b/src/variables/random.jl @@ -291,4 +291,4 @@ function initialize_output_messages!(chain::EqualityChain, randomvar::RandomVari randomvar.output_initialised = true return nothing -end \ No newline at end of file +end From 696f8ea4685fb5965873119fe84be13f6292c7ee Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Mon, 30 Jan 2023 18:14:42 +0100 Subject: [PATCH 28/48] WIP: Change data --- src/variables/data.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/variables/data.jl b/src/variables/data.jl index dadf85672..91aaba208 100644 --- a/src/variables/data.jl +++ b/src/variables/data.jl @@ -9,6 +9,7 @@ mutable struct DataVariable{D, S} <: AbstractVariable input_messages :: 
Vector{MessageObservable{AbstractMessage}} messageout :: S nconnected :: Int + # allow_missing :: Bool end Base.show(io::IO, datavar::DataVariable) = print(io, "DataVariable(", indexed_name(datavar), ")") @@ -17,6 +18,11 @@ struct DataVariableCreationOptions{S} subject::S end +allows_missing(datavar::DataVariable) = allows_missing(datavar, eltype(datavar.messageout)) + +allows_missing(datavar::DataVariable, ::Type) = true +allows_missing(datavar::DataVariable, ::Type{<:Message}) = false + Base.similar(options::DataVariableCreationOptions) = DataVariableCreationOptions(similar(options.subject)) DataVariableCreationOptions(::Type{D}) where {D} = DataVariableCreationOptions(D, nothing) From 22ce46e30768bb98ca83b34eac46dbe2d67b51a0 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Mon, 30 Jan 2023 18:56:28 +0100 Subject: [PATCH 29/48] feat: add allows_missings function & tests --- src/variables/data.jl | 15 +++--- test/variables/test_data.jl | 97 ++++++++++++++++++++++--------------- 2 files changed, 68 insertions(+), 44 deletions(-) diff --git a/src/variables/data.jl b/src/variables/data.jl index 91aaba208..ccc0b16a8 100644 --- a/src/variables/data.jl +++ b/src/variables/data.jl @@ -9,7 +9,6 @@ mutable struct DataVariable{D, S} <: AbstractVariable input_messages :: Vector{MessageObservable{AbstractMessage}} messageout :: S nconnected :: Int - # allow_missing :: Bool end Base.show(io::IO, datavar::DataVariable) = print(io, "DataVariable(", indexed_name(datavar), ")") @@ -18,11 +17,6 @@ struct DataVariableCreationOptions{S} subject::S end -allows_missing(datavar::DataVariable) = allows_missing(datavar, eltype(datavar.messageout)) - -allows_missing(datavar::DataVariable, ::Type) = true -allows_missing(datavar::DataVariable, ::Type{<:Message}) = false - Base.similar(options::DataVariableCreationOptions) = DataVariableCreationOptions(similar(options.subject)) DataVariableCreationOptions(::Type{D}) where {D} = DataVariableCreationOptions(D, nothing) @@ -84,6 +78,10 @@ function datavar(options::DataVariableCreationOptions, name::Symbol, ::Type{D}, return map(i -> datavar(similar(options), name, D, VariableVector(i)), 1:length) end +function datavar(options::DataVariableCreationOptions, name::Symbol, ::Type{D}, dim1::Int, extra_dims::Vararg{Int}) where {D} + return datavar(options, name, D, (dim1, extra_dims...)) +end + function datavar(options::DataVariableCreationOptions, name::Symbol, ::Type{D}, dims::Tuple) where {D} indices = CartesianIndices(dims) size = axes(indices) @@ -109,6 +107,11 @@ isdata(::AbstractArray{<:DataVariable}) = true isconst(::DataVariable) = false isconst(::AbstractArray{<:DataVariable}) = false +allows_missings(datavar::DataVariable) = allows_missings(datavar, eltype(datavar.messageout)) + +allows_missings(datavar::DataVariable, ::Type{ Message{D} }) where {D} = false +allows_missings(datavar::DataVariable, ::Type{ Union{Message{Missing}, Message{D}}}) where {D} = true + function Base.getindex(datavar::DataVariable, i...) error("Variable $(indexed_name(datavar)) has been indexed with `[$(join(i, ','))]`. 
Direct indexing of `data` variables is not allowed.") end diff --git a/test/variables/test_data.jl b/test/variables/test_data.jl index 5c08463f9..776a2996e 100644 --- a/test/variables/test_data.jl +++ b/test/variables/test_data.jl @@ -4,23 +4,30 @@ using Test using ReactiveMP using Rocket +import ReactiveMP: DataVariableCreationOptions import ReactiveMP: collection_type, VariableIndividual, VariableVector, VariableArray, linear_index import ReactiveMP: getconst, proxy_variables -import ReactiveMP: israndom, isproxy +import ReactiveMP: israndom, isproxy, allows_missings @testset "DataVariable" begin @testset "Simple creation" begin + randomize_update(::Type{Missing}, size) = fill(missing, size) randomize_update(::Type{T}, size) where {T <: Union{Int, Float64}} = rand(T, size) randomize_update(::Type{V}, size) where {V <: AbstractVector} = map(_ -> rand(eltype(V), 1), CartesianIndices(size)) function test_updates(vs, type, size) nupdates = 3 updates = [] - subscription = subscribe!(getmarginals(vs), (update) -> push!(updates, ReactiveMP.getdata.(update))) + subscription = subscribe!(getmarginals(vs), (update) -> begin + update_data = ReactiveMP.getdata.(update) + if all(element -> element isa type, update_data) + push!(updates, update_data) + end + end) for _ in 1:nupdates update = randomize_update(type, size) update!(vs, update) - @test last(updates) == update + @test all(last(updates) .=== update) end @test length(updates) === nupdates unsubscribe!(subscription) @@ -30,47 +37,61 @@ import ReactiveMP: israndom, isproxy return true end - for sym in (:x, :y, :z), type in (Float64, Int64, Vector{Float64}) - v = datavar(sym, type) + for sym in (:x, :y, :z), T in (Float64, Int64, Vector{Float64}), allow_missings in (true, false) + options = DataVariableCreationOptions(T, nothing, Val(allow_missings)) + variable = datavar(options, sym, T) - @test !israndom(v) - @test eltype(v) === type - @test name(v) === sym - @test collection_type(v) isa VariableIndividual - @test proxy_variables(v) === nothing - @test !isproxy(v) + @test !israndom(variable) + @test eltype(variable) === T + @test name(variable) === sym + @test collection_type(variable) isa VariableIndividual + @test proxy_variables(variable) === nothing + @test !isproxy(variable) + @test allows_missings(variable) === allow_missings end - for sym in (:x, :y, :z), type in (Float64, Int64, Vector{Float64}), n in (10, 20) - vs = datavar(sym, type, n) + for sym in (:x, :y, :z), T in (Float64, Int64, Vector{Float64}), n in (10, 20), allow_missings in (true, false) + options = DataVariableCreationOptions(T, nothing, Val(allow_missings)) + variables = datavar(options, sym, T, n) - @test !israndom(vs) - @test length(vs) === n - @test vs isa Vector - @test all(v -> !israndom(v), vs) - @test all(v -> name(v) === sym, vs) - @test all(v -> collection_type(v) isa VariableVector, vs) - @test all(t -> linear_index(collection_type(t[2])) === t[1], enumerate(vs)) - @test all(v -> eltype(v) === type, vs) - @test !isproxy(vs) - @test all(v -> !isproxy(v), vs) - @test test_updates(vs, type, (n,)) + @test !israndom(variables) + @test length(variables) === n + @test variables isa Vector + @test all(v -> !israndom(v), variables) + @test all(v -> name(v) === sym, variables) + @test all(v -> collection_type(v) isa VariableVector, variables) + @test all(t -> linear_index(collection_type(t[2])) === t[1], enumerate(variables)) + @test all(v -> eltype(v) === T, variables) + @test !isproxy(variables) + @test all(v -> !isproxy(v), variables) + @test test_updates(variables, T, 
(n,)) + + @test all(v -> allows_missings(v) === allow_missings, variables) + if allow_missings + test_updates(variables, Missing, (n, )) + end end - for sym in (:x, :y, :z), type in (Float64, Int64, Vector{Float64}), l in (10, 20), r in (10, 20) - for vs in (datavar(sym, type, l, r), datavar(sym, type, (l, r))) - @test !israndom(vs) - @test size(vs) === (l, r) - @test length(vs) === l * r - @test vs isa Matrix - @test all(v -> !israndom(v), vs) - @test all(v -> name(v) === sym, vs) - @test all(v -> collection_type(v) isa VariableArray, vs) - @test all(t -> linear_index(collection_type(t[2])) === t[1], enumerate(vs)) - @test all(v -> eltype(v) === type, vs) - @test !isproxy(vs) - @test all(v -> !isproxy(v), vs) - @test test_updates(vs, type, (l, r)) + for sym in (:x, :y, :z), T in (Float64, Int64, Vector{Float64}), l in (10, 20), r in (10, 20), allow_missings in (true, false) + options = DataVariableCreationOptions(T, nothing, Val(allow_missings)) + for variables in (datavar(options, sym, T, l, r), datavar(options, sym, T, (l, r))) + @test !israndom(variables) + @test size(variables) === (l, r) + @test length(variables) === l * r + @test variables isa Matrix + @test all(v -> !israndom(v), variables) + @test all(v -> name(v) === sym, variables) + @test all(v -> collection_type(v) isa VariableArray, variables) + @test all(t -> linear_index(collection_type(t[2])) === t[1], enumerate(variables)) + @test all(v -> eltype(v) === T, variables) + @test !isproxy(variables) + @test all(v -> !isproxy(v), variables) + @test test_updates(variables, T, (l, r)) + + @test all(v -> allows_missings(v) === allow_missings, variables) + if allow_missings + test_updates(variables, Missing, (l, r)) + end end end end From 54074d3df647a38174136d25ad6c60fd90c8513e Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Wed, 1 Feb 2023 13:14:22 +0100 Subject: [PATCH 30/48] Make format --- src/variables/data.jl | 6 +++--- test/variables/test_data.jl | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/variables/data.jl b/src/variables/data.jl index ccc0b16a8..90b66a12e 100644 --- a/src/variables/data.jl +++ b/src/variables/data.jl @@ -108,9 +108,9 @@ isconst(::DataVariable) = false isconst(::AbstractArray{<:DataVariable}) = false allows_missings(datavar::DataVariable) = allows_missings(datavar, eltype(datavar.messageout)) - -allows_missings(datavar::DataVariable, ::Type{ Message{D} }) where {D} = false -allows_missings(datavar::DataVariable, ::Type{ Union{Message{Missing}, Message{D}}}) where {D} = true +allows_missings(datavars::AbstractArray{<:DataVariable}) = all(allows_missings, datavars) +allows_missings(datavar::DataVariable, ::Type{Message{D}}) where {D} = false +allows_missings(datavar::DataVariable, ::Type{Union{Message{Missing}, Message{D}}}) where {D} = true function Base.getindex(datavar::DataVariable, i...) error("Variable $(indexed_name(datavar)) has been indexed with `[$(join(i, ','))]`. 
Direct indexing of `data` variables is not allowed.") diff --git a/test/variables/test_data.jl b/test/variables/test_data.jl index 776a2996e..683cf1a1d 100644 --- a/test/variables/test_data.jl +++ b/test/variables/test_data.jl @@ -11,14 +11,14 @@ import ReactiveMP: israndom, isproxy, allows_missings @testset "DataVariable" begin @testset "Simple creation" begin - randomize_update(::Type{Missing}, size) = fill(missing, size) + randomize_update(::Type{Missing}, size) = fill(missing, size) randomize_update(::Type{T}, size) where {T <: Union{Int, Float64}} = rand(T, size) randomize_update(::Type{V}, size) where {V <: AbstractVector} = map(_ -> rand(eltype(V), 1), CartesianIndices(size)) function test_updates(vs, type, size) nupdates = 3 updates = [] - subscription = subscribe!(getmarginals(vs), (update) -> begin + subscription = subscribe!(getmarginals(vs), (update) -> begin update_data = ReactiveMP.getdata.(update) if all(element -> element isa type, update_data) push!(updates, update_data) @@ -47,7 +47,7 @@ import ReactiveMP: israndom, isproxy, allows_missings @test collection_type(variable) isa VariableIndividual @test proxy_variables(variable) === nothing @test !isproxy(variable) - @test allows_missings(variable) === allow_missings + @test allows_missings(variable) === allow_missings end for sym in (:x, :y, :z), T in (Float64, Int64, Vector{Float64}), n in (10, 20), allow_missings in (true, false) @@ -68,7 +68,7 @@ import ReactiveMP: israndom, isproxy, allows_missings @test all(v -> allows_missings(v) === allow_missings, variables) if allow_missings - test_updates(variables, Missing, (n, )) + test_updates(variables, Missing, (n,)) end end From d9564ed260eef6e934a1a86abcc50bdae020df77 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Wed, 1 Feb 2023 16:15:22 +0100 Subject: [PATCH 31/48] improve factorisation logic for prediction variables --- .../specifications/factorisation.jl | 68 +++++++++++++++---- src/variables/data.jl | 3 +- 2 files changed, 56 insertions(+), 15 deletions(-) diff --git a/src/constraints/specifications/factorisation.jl b/src/constraints/specifications/factorisation.jl index 6ccb18e7b..66cfbc7a1 100644 --- a/src/constraints/specifications/factorisation.jl +++ b/src/constraints/specifications/factorisation.jl @@ -187,19 +187,56 @@ resolve_factorisation(::UnspecifiedConstraints, any, allvariables, fform, variab # Preoptimised dispatch rule for unspecified constraints and a deterministic node with any number of inputs resolve_factorisation(::UnspecifiedConstraints, ::Deterministic, allvariables, fform, variables) = FullFactorisation() -# Preoptimised dispatch rules for unspecified constraints and a stochastic node with 2 inputs -resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2}) where {V1 <: RandomVariable, V2 <: RandomVariable} = ((1, 2),) -resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2}) where {V1 <: Union{<:ConstVariable, <:DataVariable}, V2 <: RandomVariable} = ((1,), (2,)) -resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2}) where {V1 <: RandomVariable, V2 <: Union{<:ConstVariable, <:DataVariable}} = ((1,), (2,)) - -# Preoptimised dispatch rules for unspecified constraints and a stochastic node with 3 inputs -resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: RandomVariable, V3 <: RandomVariable} = ((1, 2, 3),) 
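# --- editor's note (illustration only, not part of the patch) ------------------
# Net effect of this rewrite of the default dispatch rules (old defaults removed
# below, data-variable-aware defaults added after them), for a two-input
# stochastic node under `UnspecifiedConstraints`:
#
#   (RandomVariable, RandomVariable)   -> ((1, 2),)      joint cluster
#   (ConstVariable,  RandomVariable)   -> ((1,), (2,))   constant factored out
#   (DataVariable,   RandomVariable)   -> ((1,), (2,))   if missings are not allowed
#   (DataVariable,   RandomVariable)   -> ((1, 2),)      if missings are allowed: the joint
#                                                        is kept so predictions can flow back
# -------------------------------------------------------------------------------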
-resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: Union{<:ConstVariable, <:DataVariable}, V2 <: RandomVariable, V3 <: RandomVariable} = ((1,), (2, 3)) -resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: Union{<:ConstVariable, <:DataVariable}, V3 <: RandomVariable} = ((1, 3), (2,)) -resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: RandomVariable, V3 <: Union{<:ConstVariable, <:DataVariable}} = ((1, 2), (3,)) -resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: Union{<:ConstVariable, <:DataVariable}, V3 <: Union{<:ConstVariable, <:DataVariable}} = ((1,), (2,), (3,)) -resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: Union{<:ConstVariable, <:DataVariable}, V2 <: RandomVariable, V3 <: Union{<:ConstVariable, <:DataVariable}} = ((1,), (2,), (3,)) -resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: Union{<:ConstVariable, <:DataVariable}, V2 <: Union{<:ConstVariable, <:DataVariable}, V3 <: RandomVariable} = ((1,), (2,), (3,)) +# Preoptimised dispatch rules for unspecified constraints and a stochastic node with 2 inputs, random variable & constant variable +resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2}) where {V1 <: RandomVariable, V2 <: RandomVariable} = ((1, 2),) +resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2}) where {V1 <: ConstVariable, V2 <: RandomVariable} = ((1,), (2,)) +resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2}) where {V1 <: RandomVariable, V2 <: ConstVariable} = ((1,), (2,)) + +# Preoptimised dispatch rules for unspecified constraints and a stochastic node with 2 inputs, random variable & data variable +resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2}) where {V1 <: DataVariable, V2 <: RandomVariable} = + allows_missings(variables[1]) ? ((1, 2),) : ((1,), (2,)) +resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2}) where {V1 <: RandomVariable, V2 <: DataVariable} = + allows_missings(variables[2]) ? 
((1, 2),) : ((1,), (2,)) + +# Preoptimised dispatch rules for unspecified constraints and a stochastic node with 3 inputs, random variable & constant variables +resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: RandomVariable, V3 <: RandomVariable} = ((1, 2, 3),) +resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: ConstVariable, V2 <: RandomVariable, V3 <: RandomVariable} = ((1,), (2, 3)) +resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: ConstVariable, V3 <: RandomVariable} = ((1, 3), (2,)) +resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: RandomVariable, V3 <: ConstVariable} = ((1, 2), (3,)) +resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: ConstVariable, V3 <: ConstVariable} = ((1,), (2,), (3,)) +resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: ConstVariable, V2 <: RandomVariable, V3 <: ConstVariable} = ((1,), (2,), (3,)) +resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: ConstVariable, V2 <: ConstVariable, V3 <: RandomVariable} = ((1,), (2,), (3,)) + +# Preoptimised dispatch rules for unspecified constraints and a stochastic node with 3 inputs, random variable & data variable +resolve_factorisation( + ::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3} +) where {V1 <: DataVariable, V2 <: RandomVariable, V3 <: RandomVariable} = allows_missings(variables[1]) ? ((1, 2, 3),) : ((1,), (2, 3)) +resolve_factorisation( + ::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3} +) where {V1 <: RandomVariable, V2 <: DataVariable, V3 <: RandomVariable} = allows_missings(variables[2]) ? ((1, 2, 3),) : ((1, 3), (2,)) +resolve_factorisation( + ::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3} +) where {V1 <: RandomVariable, V2 <: RandomVariable, V3 <: DataVariable} = allows_missings(variables[3]) ? ((1, 2, 3),) : ((1, 2), (3,)) + +# Preoptimised dispatch rules for unspecified constraints and a stochastic node with 3 inputs, random variable & data variable & const variable +resolve_factorisation( + ::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3} +) where {V1 <: DataVariable, V2 <: ConstVariable, V3 <: RandomVariable} = allows_missings(variables[1]) ? ((1, 3), (2,)) : ((1,), (2,), (3,)) +resolve_factorisation( + ::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3} +) where {V1 <: DataVariable, V2 <: RandomVariable, V3 <: ConstVariable} = allows_missings(variables[1]) ? ((1, 2), (3,)) : ((1,), (2,), (3,)) +resolve_factorisation( + ::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3} +) where {V1 <: ConstVariable, V2 <: DataVariable, V3 <: RandomVariable} = allows_missings(variables[2]) ? 
((1,), (2, 3)) : ((1,), (2,), (3,))
resolve_factorisation(
    ::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3}
) where {V1 <: RandomVariable, V2 <: DataVariable, V3 <: ConstVariable} = allows_missings(variables[2]) ? ((1, 2), (3,)) : ((1,), (2,), (3,))
resolve_factorisation(
    ::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3}
) where {V1 <: ConstVariable, V2 <: RandomVariable, V3 <: DataVariable} = allows_missings(variables[3]) ? ((1,), (2, 3)) : ((1,), (2,), (3,))
resolve_factorisation(
    ::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3}
) where {V1 <: RandomVariable, V2 <: ConstVariable, V3 <: DataVariable} = allows_missings(variables[3]) ? ((1, 3), (2,)) : ((1,), (2,), (3,))

"""
    resolve_factorisation(constraints, allvariables, fform, variables)
@@ -419,8 +456,11 @@ function resolve_factorisation(::Stochastic, constraints, allvariables, fform, _
     index::Int = 1
     shift::Int = 0
     for varref in var_refs
-        if israndom(varref[3])
+        if israndom(varref[3]) || (isdata(varref[3]) && allows_missings(varref[3]))
             # We process everything as usual if varref is a random variable
+            # or if the variable is a data variable that allows missings
+            # We probably should change the logic from "allows missings" to "used as prediction"
+            # For now we assume that if a data variable allows missing inputs it is indeed "used as prediction"
             __process_factorisation_entry!(varref[1], varref[2], shift)
         else
             # We filter out varref from all clusters if it is not random
diff --git a/src/variables/data.jl b/src/variables/data.jl
index 90b66a12e..beb054103 100644
--- a/src/variables/data.jl
+++ b/src/variables/data.jl
@@ -108,9 +108,10 @@ isconst(::DataVariable) = false
 isconst(::AbstractArray{<:DataVariable}) = false
 
 allows_missings(datavar::DataVariable) = allows_missings(datavar, eltype(datavar.messageout))
+allows_missings(datavars::AbstractArray{<:DataVariable}) = all(allows_missings, datavars)
 allows_missings(datavar::DataVariable, ::Type{Message{D}}) where {D} = false
-allows_missings(datavar::DataVariable, ::Type{Union{Message{Missing}, Message{D}}}) where {D} = true
+allows_missings(datavar::DataVariable, ::Type{Union{Message{Missing}, Message{D}}} where {D}) = true
 
 function Base.getindex(datavar::DataVariable, i...)
     error("Variable $(indexed_name(datavar)) has been indexed with `[$(join(i, ','))]`. Direct indexing of `data` variables is not allowed.")
 end
From e936392ec16beea9ac48437e8e4cc7ab3c6f3aad Mon Sep 17 00:00:00 2001
From: Bagaev Dmitry
Date: Wed, 1 Feb 2023 21:10:11 +0100
Subject: [PATCH 32/48] fix: update warning for factorisation check

---
 src/node.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/node.jl b/src/node.jl
index 6cde48ac6..56b5e5e19 100644
--- a/src/node.jl
+++ b/src/node.jl
@@ -1144,7 +1144,7 @@ macro node(fformtype, sdtype, interfaces_list)
             missingclustererr = "Cannot find the cluster for the variable connected to the `$(name)` interface around the `$fformtype` node."
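# (editor's note, not part of the patch) the generated check below is now skipped
# for data variables that allow missings: such variables act as predictions and,
# per the updated defaults in factorisation.jl, may legitimately share a
# factorisation cluster with random variables instead of forming their own.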
quote # If a variable `$name` is a constvar or a datavar - if ReactiveMP.isconst($(name)) || ReactiveMP.isdata($(name)) + if ReactiveMP.isconst($(name)) || (ReactiveMP.isdata($(name)) && !ReactiveMP.allows_missings($(name))) local __factorisation = ReactiveMP.factorisation(node) # Find the factorization cluster associated with the constvar `$name` local __index = ReactiveMP.interface_get_index(Val{$(QuoteNode(fbottomtype))}, Val{$(QuoteNode(name))}) From 07c5bdce21008ba8f61ee6d12edea5c7dac995d9 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Mon, 6 Feb 2023 14:44:58 +0100 Subject: [PATCH 33/48] Update mapping for marginal --- src/marginal.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/marginal.jl b/src/marginal.jl index 72d2586e6..7942d55b9 100644 --- a/src/marginal.jl +++ b/src/marginal.jl @@ -190,8 +190,14 @@ function (mapping::MarginalMapping)(dependencies) # Marginal is initial if it is not clamped and all of the inputs are either clamped or initial is_marginal_initial = !is_marginal_clamped && (__check_all(is_clamped_or_initial, messages) && __check_all(is_clamped_or_initial, marginals)) - marginal = marginalrule(marginal_mapping_fform(mapping), mapping.vtag, mapping.msgs_names, messages, mapping.marginals_names, marginals, mapping.meta, mapping.factornode) - + marginal = if !isnothing(messages) && any(ismissing, TupleTools.flatten(getdata.(messages))) + missing + elseif !isnothing(marginals) && any(ismissing, TupleTools.flatten(getdata.(marginals))) + missing + else + marginalrule(marginal_mapping_fform(mapping), mapping.vtag, mapping.msgs_names, messages, mapping.marginals_names, marginals, mapping.meta, mapping.factornode) + end + return Marginal(marginal, is_marginal_clamped, is_marginal_initial, nothing) end From 37d0a72fb89c77597757da56ac855e12a9ee8f9e Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Tue, 7 Feb 2023 10:08:13 +0100 Subject: [PATCH 34/48] Make format --- src/marginal.jl | 2 +- wip_mar.jl | 124 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 125 insertions(+), 1 deletion(-) create mode 100644 wip_mar.jl diff --git a/src/marginal.jl b/src/marginal.jl index 7942d55b9..72cd88193 100644 --- a/src/marginal.jl +++ b/src/marginal.jl @@ -197,7 +197,7 @@ function (mapping::MarginalMapping)(dependencies) else marginalrule(marginal_mapping_fform(mapping), mapping.vtag, mapping.msgs_names, messages, mapping.marginals_names, marginals, mapping.meta, mapping.factornode) end - + return Marginal(marginal, is_marginal_clamped, is_marginal_initial, nothing) end diff --git a/wip_mar.jl b/wip_mar.jl new file mode 100644 index 000000000..28422fda1 --- /dev/null +++ b/wip_mar.jl @@ -0,0 +1,124 @@ +using RxInfer +using LinearAlgebra +using Plots +using Random + +Random.seed!(42) + +function generate_mar(order, ds, n_samples) + As = [0.1*randn(ds,ds) for _ in 1:order] + + x = [randn(ds) for _ in 1:order] + y = deepcopy(x) + + for _ in 1:n_samples + m = mapreduce(x -> x[1] * x[2], +, zip(As, x[end:-1:end-order+1])) + dist = MvNormal(m, diageye(ds)) + push!(x, rand(dist)) + push!(y, rand(MvNormal(x[end], diageye(ds)))) + end + collect(Iterators.flatten(hcat(As))), x, y +end + +@model function multivariateAR(n_samples, order, dimension) + + o = datavar(Vector{Float64}, n_samples) + y = randomvar(n_samples) + + a ~ MvNormalMeanCovariance(randn(dimension^2*order), diageye(dimension^2*order)) + + # NOTE: Wishart is naughty + Λ ~ Wishart(dimension, diageye(dimension)) + + B = zeros(dimension, dimension*order); B[1:dimension, 
1:dimension] = diageye(dimension) + + x ~ MvNormalMeanCovariance(zeros(dimension*order), diageye(dimension*order)) + + x_prev = x + for i in 1:n_samples + y[i] ~ MAR(x_prev, a, Λ) where {meta = MARMeta(order, dimension)} + + o[i] ~ MvNormalMeanCovariance(B*y[i], diageye(dimension)) + x_prev = y[i] + end +end + +constraints = @constraints begin + q(y, x, a, Λ) = q(y, x)q(a)q(Λ) +end + +# constraints = @constraints begin +# q(y, x, a, Λ, τ) = q(y, x)q(a)q(Λ)q(τ) +# end + +# mf_constraints = @constraints begin +# q(y, x, a, Λ) = q(y)q(x)q(a)q(Λ) +# end + +n = 100 + +d = 4 +p = 5 +coefs, lat, obs = generate_mar(p, d, n) + +transform_obs = [vcat(obs[i], obs[i-1]) for i in 2:n+1] +corrected_obs = [x[1:d] for x in transform_obs] + +mdata = (o = corrected_obs, ) +minitmarginals = (Λ = Wishart(d, diageye(d)), a = MvNormalMeanPrecision(zeros(d^2*p), diageye(d^2*p))) +minitmarginals = (Λ = Wishart(d, diageye(d)), a = MvNormalMeanPrecision(zeros(d^2*p), diageye(d^2*p)), x = vague(MvNormalMeanCovariance, d*p), y = vague(MvNormalMeanCovariance, d*p)) + +# minitmarginals = (Λ = Wishart(d+1, diageye(d)), a = MvNormalMeanPrecision(zeros(d^2*p), diageye(d^2*p)), τ = GammaShapeRate(1.0, 1.0)) + +# First execution is slow due to Julia's initial compilation +mresult = inference( + model = multivariateAR(n, p, d), + data = mdata, + # constraints = mf_constraints, + constraints = constraints, + initmarginals = minitmarginals, + free_energy = true, + iterations = 25, + showprogress = true, +) + + +scatter(coefs) +plot!(mean(mresult.posteriors[:a][end]), ribbon=sqrt.(var(mresult.posteriors[:a][end]))) +vline!([i*d*d for i in 1:p]) + +plot(mresult.free_energy[1:end]) + +lat_mean = mean.(mresult.posteriors[:y][end]) +lat_var = var.(mresult.posteriors[:y][end]) + +lat_mean₁ = first.(lat_mean) +lat_mean₂ = getindex.(lat_mean, 2) +lat_mean₃ = getindex.(lat_mean, 3) +lat_mean₄ = getindex.(lat_mean, 4) + +lat_var₁ = first.(lat_var) +lat_var₂ = getindex.(lat_var, 2) +lat_var₃ = getindex.(lat_var, 3) +lat_var₄ = getindex.(lat_var, 4) + +Λ_inf = mresult.posteriors[:Λ][end] +@show mean(Λ_inf) + +scatter(first.(mdata[:o])) +plot!(first.(lat[2:end])) +plot!(lat_mean₁, ribbon=sqrt.(lat_var₁)) + +scatter(getindex.(mdata[:o], 2)) +plot!(getindex.(lat[2:end], 2)) +plot!(lat_mean₂, ribbon=sqrt.(lat_var₂)) + +scatter(getindex.(mdata[:o], 3)) +plot!(getindex.(lat[2:end], 3)) +plot!(lat_mean₃, ribbon=sqrt.(lat_var₃)) + +scatter(getindex.(mdata[:o], 4)) +plot!(getindex.(lat[2:end], 4)) +plot!(lat_mean₄, ribbon=sqrt.(lat_var₄)) + + From 2bb563fec84c6569480e6bb89cf732f236918724 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Tue, 7 Feb 2023 10:10:56 +0100 Subject: [PATCH 35/48] Delete WIPs --- wip_mar.jl | 124 ----------------------------------------------------- 1 file changed, 124 deletions(-) delete mode 100644 wip_mar.jl diff --git a/wip_mar.jl b/wip_mar.jl deleted file mode 100644 index 28422fda1..000000000 --- a/wip_mar.jl +++ /dev/null @@ -1,124 +0,0 @@ -using RxInfer -using LinearAlgebra -using Plots -using Random - -Random.seed!(42) - -function generate_mar(order, ds, n_samples) - As = [0.1*randn(ds,ds) for _ in 1:order] - - x = [randn(ds) for _ in 1:order] - y = deepcopy(x) - - for _ in 1:n_samples - m = mapreduce(x -> x[1] * x[2], +, zip(As, x[end:-1:end-order+1])) - dist = MvNormal(m, diageye(ds)) - push!(x, rand(dist)) - push!(y, rand(MvNormal(x[end], diageye(ds)))) - end - collect(Iterators.flatten(hcat(As))), x, y -end - -@model function multivariateAR(n_samples, order, dimension) - - o = datavar(Vector{Float64}, 
n_samples) - y = randomvar(n_samples) - - a ~ MvNormalMeanCovariance(randn(dimension^2*order), diageye(dimension^2*order)) - - # NOTE: Wishart is naughty - Λ ~ Wishart(dimension, diageye(dimension)) - - B = zeros(dimension, dimension*order); B[1:dimension, 1:dimension] = diageye(dimension) - - x ~ MvNormalMeanCovariance(zeros(dimension*order), diageye(dimension*order)) - - x_prev = x - for i in 1:n_samples - y[i] ~ MAR(x_prev, a, Λ) where {meta = MARMeta(order, dimension)} - - o[i] ~ MvNormalMeanCovariance(B*y[i], diageye(dimension)) - x_prev = y[i] - end -end - -constraints = @constraints begin - q(y, x, a, Λ) = q(y, x)q(a)q(Λ) -end - -# constraints = @constraints begin -# q(y, x, a, Λ, τ) = q(y, x)q(a)q(Λ)q(τ) -# end - -# mf_constraints = @constraints begin -# q(y, x, a, Λ) = q(y)q(x)q(a)q(Λ) -# end - -n = 100 - -d = 4 -p = 5 -coefs, lat, obs = generate_mar(p, d, n) - -transform_obs = [vcat(obs[i], obs[i-1]) for i in 2:n+1] -corrected_obs = [x[1:d] for x in transform_obs] - -mdata = (o = corrected_obs, ) -minitmarginals = (Λ = Wishart(d, diageye(d)), a = MvNormalMeanPrecision(zeros(d^2*p), diageye(d^2*p))) -minitmarginals = (Λ = Wishart(d, diageye(d)), a = MvNormalMeanPrecision(zeros(d^2*p), diageye(d^2*p)), x = vague(MvNormalMeanCovariance, d*p), y = vague(MvNormalMeanCovariance, d*p)) - -# minitmarginals = (Λ = Wishart(d+1, diageye(d)), a = MvNormalMeanPrecision(zeros(d^2*p), diageye(d^2*p)), τ = GammaShapeRate(1.0, 1.0)) - -# First execution is slow due to Julia's initial compilation -mresult = inference( - model = multivariateAR(n, p, d), - data = mdata, - # constraints = mf_constraints, - constraints = constraints, - initmarginals = minitmarginals, - free_energy = true, - iterations = 25, - showprogress = true, -) - - -scatter(coefs) -plot!(mean(mresult.posteriors[:a][end]), ribbon=sqrt.(var(mresult.posteriors[:a][end]))) -vline!([i*d*d for i in 1:p]) - -plot(mresult.free_energy[1:end]) - -lat_mean = mean.(mresult.posteriors[:y][end]) -lat_var = var.(mresult.posteriors[:y][end]) - -lat_mean₁ = first.(lat_mean) -lat_mean₂ = getindex.(lat_mean, 2) -lat_mean₃ = getindex.(lat_mean, 3) -lat_mean₄ = getindex.(lat_mean, 4) - -lat_var₁ = first.(lat_var) -lat_var₂ = getindex.(lat_var, 2) -lat_var₃ = getindex.(lat_var, 3) -lat_var₄ = getindex.(lat_var, 4) - -Λ_inf = mresult.posteriors[:Λ][end] -@show mean(Λ_inf) - -scatter(first.(mdata[:o])) -plot!(first.(lat[2:end])) -plot!(lat_mean₁, ribbon=sqrt.(lat_var₁)) - -scatter(getindex.(mdata[:o], 2)) -plot!(getindex.(lat[2:end], 2)) -plot!(lat_mean₂, ribbon=sqrt.(lat_var₂)) - -scatter(getindex.(mdata[:o], 3)) -plot!(getindex.(lat[2:end], 3)) -plot!(lat_mean₃, ribbon=sqrt.(lat_var₃)) - -scatter(getindex.(mdata[:o], 4)) -plot!(getindex.(lat[2:end], 4)) -plot!(lat_mean₄, ribbon=sqrt.(lat_var₄)) - - From 65cdae88f2ef0c9db642dff520f527caea0c1158 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Tue, 7 Feb 2023 10:25:48 +0100 Subject: [PATCH 36/48] Make format --- src/marginal.jl | 2 +- src/nodes/mv_autoregressive.jl | 78 +++++++++------------- src/rules/mv_autoregressive/a.jl | 34 +++++----- src/rules/mv_autoregressive/lambda.jl | 85 +++++++++++------------- src/rules/mv_autoregressive/marginals.jl | 21 ++---- src/rules/mv_autoregressive/x.jl | 24 +++---- src/rules/mv_autoregressive/y.jl | 17 ++--- 7 files changed, 111 insertions(+), 150 deletions(-) diff --git a/src/marginal.jl b/src/marginal.jl index 7942d55b9..72cd88193 100644 --- a/src/marginal.jl +++ b/src/marginal.jl @@ -197,7 +197,7 @@ function (mapping::MarginalMapping)(dependencies) else 
marginalrule(marginal_mapping_fform(mapping), mapping.vtag, mapping.msgs_names, messages, mapping.marginals_names, marginals, mapping.meta, mapping.factornode) end - + return Marginal(marginal, is_marginal_clamped, is_marginal_initial, nothing) end diff --git a/src/nodes/mv_autoregressive.jl b/src/nodes/mv_autoregressive.jl index 35cd47056..1cc739b52 100644 --- a/src/nodes/mv_autoregressive.jl +++ b/src/nodes/mv_autoregressive.jl @@ -11,7 +11,7 @@ struct MARMeta order :: Int # order (lag) of MAR ds :: Int # dimensionality of MAR process, i.e., the number of correlated AR processes - function MARMeta(order, ds=2) + function MARMeta(order, ds = 2) if ds < 2 @error "ds parameter should be > 1. Use AR node if ds = 1" end @@ -19,36 +19,29 @@ struct MARMeta end end - -getorder(meta::MARMeta) = meta.order -getdimensionality(meta::MARMeta) = meta.ds +getorder(meta::MARMeta) = meta.order +getdimensionality(meta::MARMeta) = meta.ds @node MAR Stochastic [y, x, a, Λ] default_meta(::Type{MAR}) = error("MvAutoregressive node requires meta flag explicitly specified") -@average_energy MAR ( - q_y_x::MultivariateNormalDistributionsFamily, - q_a::MultivariateNormalDistributionsFamily, - q_Λ::Wishart, - meta::MARMeta -) = begin +@average_energy MAR (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Wishart, meta::MARMeta) = begin ma, Va = mean_cov(q_a) myx, Vyx = mean_cov(q_y_x) mΛ = mean(q_Λ) order, ds = getorder(meta), getdimensionality(meta) F = Multivariate - dim = order*ds + dim = order * ds n = div(ndims(q_y_x), 2) - ma, Va = mean_cov(q_a) mA = mar_companion_matrix(order, ds, ma)[1:ds, 1:dim] - mx, Vx = ar_slice(F, myx, (dim+1):2dim), ar_slice(F, Vyx, (dim+1):2dim, (dim+1):2dim) + mx, Vx = ar_slice(F, myx, (dim + 1):(2dim)), ar_slice(F, Vyx, (dim + 1):(2dim), (dim + 1):(2dim)) my1, Vy1 = myx[1:ds], Vyx[1:ds, 1:ds] - Vy1x = ar_slice(F, Vyx, 1:ds, dim+1:2dim) + Vy1x = ar_slice(F, Vyx, 1:ds, (dim + 1):(2dim)) # @show Vyx # @show Vy1x @@ -57,16 +50,16 @@ default_meta(::Type{MAR}) = error("MvAutoregressive node requires meta flag expl es = [uvector(ds, i) for i in 1:ds] Fs = [mask_mar(order, ds, i) for i in 1:ds] - g₁ = my1'*mΛ*my1 + tr(Vy1*mΛ) - g₂ = mx'*mA'*mΛ*my1 + tr(Vy1x*mA'*mΛ) + g₁ = my1' * mΛ * my1 + tr(Vy1 * mΛ) + g₂ = mx' * mA' * mΛ * my1 + tr(Vy1x * mA' * mΛ) g₃ = g₂ - G = sum(sum(es[i]'*mΛ*es[j]*Fs[i]*(ma*ma' + Va)*Fs[j]' for i in 1:ds) for j in 1:ds) - g₄ = mx'*G*mx + tr(Vx*G) - AE = n/2*log2π - 0.5*mean(logdet, q_Λ) + 0.5*(g₁ - g₂ - g₃ + g₄) + G = sum(sum(es[i]' * mΛ * es[j] * Fs[i] * (ma * ma' + Va) * Fs[j]' for i in 1:ds) for j in 1:ds) + g₄ = mx' * G * mx + tr(Vx * G) + AE = n / 2 * log2π - 0.5 * mean(logdet, q_Λ) + 0.5 * (g₁ - g₂ - g₃ + g₄) if order > 1 AE += entropy(q_y_x) - idc = LazyArrays.Vcat(1:ds, dim+1:2dim) + idc = LazyArrays.Vcat(1:ds, (dim + 1):(2dim)) myx_n = view(myx, idc) Vyx_n = view(Vyx, idc, idc) q_y_x = MvNormalMeanCovariance(myx_n, Vyx_n) @@ -77,20 +70,16 @@ default_meta(::Type{MAR}) = error("MvAutoregressive node requires meta flag expl end @average_energy MAR ( - q_y::MultivariateNormalDistributionsFamily, - q_x::MultivariateNormalDistributionsFamily, - q_a::MultivariateNormalDistributionsFamily, - q_Λ::Wishart, - meta::MARMeta + q_y::MultivariateNormalDistributionsFamily, q_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Wishart, meta::MARMeta ) = begin - ma, Va = mean_cov(q_a) + ma, Va = mean_cov(q_a) my, Vy = mean_cov(q_y) mx, Vx = mean_cov(q_y) - mΛ = mean(q_Λ) + mΛ = mean(q_Λ) 
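     # factorised q(y)q(x) case: the cross-covariance between y and x is zero,
     # so it does not enter the g₁…g₄ terms computed below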
order, ds = getorder(meta), getdimensionality(meta) F = Multivariate - dim = order*ds + dim = order * ds n = dim ma, Va = mean_cov(q_a) @@ -102,12 +91,12 @@ end es = [uvector(ds, i) for i in 1:ds] Fs = [mask_mar(order, ds, i) for i in 1:ds] - g₁ = my1'*mΛ*my1 + tr(Vy1*mΛ) - g₂ = -mx'*mA'*mΛ*my1 + g₁ = my1' * mΛ * my1 + tr(Vy1 * mΛ) + g₂ = -mx' * mA' * mΛ * my1 g₃ = -g₂ - G = sum(sum(es[i]'*mΛ*es[j]*Fs[i]*(ma*ma' + Va)*Fs[j]' for i in 1:ds) for j in 1:ds) - g₄ = mx'*G*mx + tr(Vx*G) - AE = n/2*log2π - 0.5*mean(logdet, q_Λ) + 0.5*(g₁ + g₂ + g₃ + g₄) + G = sum(sum(es[i]' * mΛ * es[j] * Fs[i] * (ma * ma' + Va) * Fs[j]' for i in 1:ds) for j in 1:ds) + g₄ = mx' * G * mx + tr(Vx * G) + AE = n / 2 * log2π - 0.5 * mean(logdet, q_Λ) + 0.5 * (g₁ + g₂ + g₃ + g₄) if order > 1 AE += entropy(q_y) @@ -120,13 +109,13 @@ end # Helpers for AR rules function mask_mar(order, dimension, index) - F = zeros(dimension*order, dimension*dimension*order) + F = zeros(dimension * order, dimension * dimension * order) rows = repeat([dimension], order) - cols = repeat([dimension], dimension*order) + cols = repeat([dimension], dimension * order) FB = BlockArrays.BlockArray(F, rows, cols) for k in 1:order - for j in 1:dimension*order - if j == index + (k-1)*dimension + for j in 1:(dimension * order) + if j == index + (k - 1) * dimension view(FB, BlockArrays.Block(k, j)) .= diageye(dimension) end end @@ -136,33 +125,32 @@ end function mar_transition(order, Λ) dim = size(Λ, 1) - W = 1.0*diageye(dim*order) + W = 1.0 * diageye(dim * order) W[1:dim, 1:dim] = Λ return W end - function mar_shift(order, ds) - dim = order*ds + dim = order * ds S = diageye(dim) - for i in dim:-1:ds+1 - S[i,:] = S[i-ds, :] + for i in dim:-1:(ds + 1) + S[i, :] = S[i - ds, :] end S[1:ds, :] = zeros(ds, dim) return S end -function uvector(dim, pos=1) +function uvector(dim, pos = 1) u = zeros(dim) u[pos] = 1 return dim == 1 ? 
u[pos] : u end function mar_companion_matrix(order, ds, a) - dim = order*ds + dim = order * ds S = mar_shift(order, ds) es = [uvector(dim, i) for i in 1:ds] Fs = [mask_mar(order, ds, i) for i in 1:ds] - L = S .+ sum(es[i]*a'*Fs[i]' for i in 1:ds) + L = S .+ sum(es[i] * a' * Fs[i]' for i in 1:ds) return L end diff --git a/src/rules/mv_autoregressive/a.jl b/src/rules/mv_autoregressive/a.jl index c8f2459c9..0777efb7b 100644 --- a/src/rules/mv_autoregressive/a.jl +++ b/src/rules/mv_autoregressive/a.jl @@ -1,16 +1,15 @@ @rule MAR(:a, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin - order, ds = getorder(meta), getdimensionality(meta) F = Multivariate - dim = order*ds + dim = order * ds - m, V = mean_cov(q_y_x) + m, V = mean_cov(q_y_x) - my, Vy = ar_slice(F, m, 1:dim), ar_slice(F, V, 1:dim, 1:dim) - mx, Vx = ar_slice(F, m, (dim+1):2dim), ar_slice(F, V, (dim+1):2dim, (dim+1):2dim) - Vyx = ar_slice(F, V, 1:dim, dim+1:2dim) + my, Vy = ar_slice(F, m, 1:dim), ar_slice(F, V, 1:dim, 1:dim) + mx, Vx = ar_slice(F, m, (dim + 1):(2dim)), ar_slice(F, V, (dim + 1):(2dim), (dim + 1):(2dim)) + Vyx = ar_slice(F, V, 1:dim, (dim + 1):(2dim)) mΛ = mean(q_Λ) mW = mar_transition(order, mΛ) @@ -21,23 +20,21 @@ S = mar_shift(order, ds) # NOTE: prove that sum(Fs[i]'*((mx*mx'+Vx')*S')*mW*es[i] for i in 1:ds) == 0.0 - D = sum(sum(es[i]'*mW*es[j]*Fs[i]'*(mx*mx' + Vx)*Fs[j] for i in 1:ds) for j in 1:ds) - z = sum(Fs[i]'*(mx*my'+Vyx')*mW*es[i] for i in 1:ds) + D = sum(sum(es[i]' * mW * es[j] * Fs[i]' * (mx * mx' + Vx) * Fs[j] for i in 1:ds) for j in 1:ds) + z = sum(Fs[i]' * (mx * my' + Vyx') * mW * es[i] for i in 1:ds) return MvNormalWeightedMeanPrecision(z, D) end - @rule MAR(:a, Marginalisation) (q_y::MultivariateNormalDistributionsFamily, q_x::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin - order, ds = getorder(meta), getdimensionality(meta) F = Multivariate - - dim = order*ds - my, Vy = mean_cov(q_y) - mx, Vx = mean_cov(q_x) - mΛ = mean(q_Λ) + dim = order * ds + + my, Vy = mean_cov(q_y) + mx, Vx = mean_cov(q_x) + mΛ = mean(q_Λ) mW = mar_transition(order, mΛ) S = mar_shift(order, ds) @@ -46,9 +43,8 @@ end es = [uvector(dim, i) for i in 1:ds] Fs = [mask_mar(order, ds, i) for i in 1:ds] - - D = sum(sum(es[j]'*mW*es[i]*Fs[i]'*(mx*mx' + Vx)*Fs[j] for i in 1:ds) for j in 1:ds) - z = sum(Fs[i]'*((mx*mx'+Vx')*S' + mx*my')*mW*es[i] for i in 1:ds) + D = sum(sum(es[j]' * mW * es[i] * Fs[i]' * (mx * mx' + Vx) * Fs[j] for i in 1:ds) for j in 1:ds) + z = sum(Fs[i]' * ((mx * mx' + Vx') * S' + mx * my') * mW * es[i] for i in 1:ds) return MvNormalWeightedMeanPrecision(z, D) -end \ No newline at end of file +end diff --git a/src/rules/mv_autoregressive/lambda.jl b/src/rules/mv_autoregressive/lambda.jl index f73d30c05..29f88cbae 100644 --- a/src/rules/mv_autoregressive/lambda.jl +++ b/src/rules/mv_autoregressive/lambda.jl @@ -1,62 +1,53 @@ -@rule MAR(:Λ, Marginalisation) ( - q_y_x::MultivariateNormalDistributionsFamily, - q_a::MultivariateNormalDistributionsFamily, - meta::MARMeta -) = begin +@rule MAR(:Λ, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, meta::MARMeta) = begin order, ds = getorder(meta), getdimensionality(meta) F = Multivariate - dim = order*ds - + dim = order * ds + ma, Va = mean_cov(q_a) mA = mar_companion_matrix(order, ds, ma) - m, V = mean_cov(q_y_x) - my, Vy = ar_slice(F, m, 1:dim), ar_slice(F, V, 1:dim, 1:dim) - mx, Vx = ar_slice(F, m, (dim+1):2dim), ar_slice(F, V, (dim+1):2dim, 
(dim+1):2dim) - Vyx = ar_slice(F, V, 1:dim, dim+1:2dim) + m, V = mean_cov(q_y_x) + my, Vy = ar_slice(F, m, 1:dim), ar_slice(F, V, 1:dim, 1:dim) + mx, Vx = ar_slice(F, m, (dim + 1):(2dim)), ar_slice(F, V, (dim + 1):(2dim), (dim + 1):(2dim)) + Vyx = ar_slice(F, V, 1:dim, (dim + 1):(2dim)) es = [uvector(dim, i) for i in 1:ds] Fs = [mask_mar(order, ds, i) for i in 1:ds] S = mar_shift(order, ds) - G₁ = (my*my' + Vy)[1:ds, 1:ds] - G₂ = ((my*mx' + Vyx)*mA')[1:ds, 1:ds] + G₁ = (my * my' + Vy)[1:ds, 1:ds] + G₂ = ((my * mx' + Vyx) * mA')[1:ds, 1:ds] G₃ = transpose(G₂) - Ex_xx = mx*mx' + Vx - G₅ = sum(sum(es[i]*ma'*Fs[i]'Ex_xx*Fs[j]*ma*es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] - G₆ = sum(sum(es[i]*tr(Fs[i]'*Ex_xx*Fs[j]*Va)*es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] + Ex_xx = mx * mx' + Vx + G₅ = sum(sum(es[i] * ma' * Fs[i]'Ex_xx * Fs[j] * ma * es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] + G₆ = sum(sum(es[i] * tr(Fs[i]' * Ex_xx * Fs[j] * Va) * es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] Δ = G₁ - G₂ - G₃ + G₅ + G₆ - return WishartMessage(ds+2, Δ) - + return WishartMessage(ds + 2, Δ) end -@rule MAR(:Λ, Marginalisation) ( - q_y::MultivariateNormalDistributionsFamily, - q_x::MultivariateNormalDistributionsFamily, - q_a::MultivariateNormalDistributionsFamily, - meta::MARMeta -) = begin - order, ds = getorder(meta), getdimensionality(meta) - F = Multivariate - dim = order*ds - - my, Vy = mean_cov(q_y) - mx, Vx = mean_cov(q_x) - ma, Va = mean_cov(q_a) - - mA = mar_companion_matrix(order, ds, ma) - - es = [uvector(dim, i) for i in 1:ds] - Fs = [mask_mar(order, ds, i) for i in 1:ds] - S = mar_shift(order, ds) - G₁ = (my*my' + Vy)[1:ds, 1:ds] - G₂ = (my*mx'*mA')[1:ds, 1:ds] - G₃ = transpose(G₂) - Ex_xx = mx*mx' + Vx - G₅ = sum(sum(es[i]*ma'*Fs[j]'Ex_xx*Fs[i]*ma*es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] - G₆ = sum(sum(es[i]*tr(Va*Fs[i]'*Ex_xx*Fs[j])*es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] - Δ = G₁ - G₂ - G₃ + G₅ + G₆ - - return WishartMessage(ds+2, Δ) -end \ No newline at end of file +@rule MAR(:Λ, Marginalisation) (q_y::MultivariateNormalDistributionsFamily, q_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, meta::MARMeta) = + begin + order, ds = getorder(meta), getdimensionality(meta) + F = Multivariate + dim = order * ds + + my, Vy = mean_cov(q_y) + mx, Vx = mean_cov(q_x) + ma, Va = mean_cov(q_a) + + mA = mar_companion_matrix(order, ds, ma) + + es = [uvector(dim, i) for i in 1:ds] + Fs = [mask_mar(order, ds, i) for i in 1:ds] + S = mar_shift(order, ds) + G₁ = (my * my' + Vy)[1:ds, 1:ds] + G₂ = (my * mx' * mA')[1:ds, 1:ds] + G₃ = transpose(G₂) + Ex_xx = mx * mx' + Vx + G₅ = sum(sum(es[i] * ma' * Fs[j]'Ex_xx * Fs[i] * ma * es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] + G₆ = sum(sum(es[i] * tr(Va * Fs[i]' * Ex_xx * Fs[j]) * es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] + Δ = G₁ - G₂ - G₃ + G₅ + G₆ + + return WishartMessage(ds + 2, Δ) + end diff --git a/src/rules/mv_autoregressive/marginals.jl b/src/rules/mv_autoregressive/marginals.jl index 51b7a1deb..92a71a659 100644 --- a/src/rules/mv_autoregressive/marginals.jl +++ b/src/rules/mv_autoregressive/marginals.jl @@ -1,29 +1,20 @@ @marginalrule MAR(:y_x) ( - m_y::MultivariateNormalDistributionsFamily, - m_x::MultivariateNormalDistributionsFamily, - q_a::MultivariateNormalDistributionsFamily, - q_Λ::Any, - meta::MARMeta + m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta ) = begin 
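     # delegates to the `ar_y_x_marginal` helper defined below, which assembles the structured joint marginal over (y, x)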
return ar_y_x_marginal(m_y, m_x, q_a, q_Λ, meta) end function ar_y_x_marginal( - m_y::MultivariateNormalDistributionsFamily, - m_x::MultivariateNormalDistributionsFamily, - q_a::MultivariateNormalDistributionsFamily, - q_Λ::Any, - meta::MARMeta) - + m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta +) order, ds = getorder(meta), getdimensionality(meta) F = Multivariate - dim = order*ds - + dim = order * ds ma, Va = mean_cov(q_a) mΛ = mean(q_Λ) - + mA = mar_companion_matrix(order, ds, ma) mW = mar_transition(getorder(meta), mΛ) @@ -37,7 +28,7 @@ function ar_y_x_marginal( es = [uvector(dim, i) for i in 1:ds] Fs = [mask_mar(order, ds, i) for i in 1:ds] - Ξ = inv_f_Vx + sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:ds) for j in 1:ds) + Ξ = inv_f_Vx + sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds) W_11 = inv_b_Vy + mW diff --git a/src/rules/mv_autoregressive/x.jl b/src/rules/mv_autoregressive/x.jl index 20addd144..e191589ad 100644 --- a/src/rules/mv_autoregressive/x.jl +++ b/src/rules/mv_autoregressive/x.jl @@ -1,13 +1,12 @@ -@rule MAR(:x, Marginalisation) (m_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = -begin +@rule MAR(:x, Marginalisation) (m_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin ma, Va = mean_cov(q_a) my, Vy = mean_cov(m_y) mΛ = mean(q_Λ) order, ds = getorder(meta), getdimensionality(meta) - dim = order*ds + dim = order * ds mA = mar_companion_matrix(order, ds, ma) mW = mar_transition(getorder(meta), mΛ) @@ -15,38 +14,37 @@ begin es = [uvector(dim, i) for i in 1:ds] Fs = [mask_mar(order, ds, i) for i in 1:ds] - Λ = sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:ds) for j in 1:ds) + Λ = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds) - Σ₁ = Hermitian(pinv(mA)*(Vy)*pinv(mA') + pinv(mA'*mW*mA)) + Σ₁ = Hermitian(pinv(mA) * (Vy) * pinv(mA') + pinv(mA' * mW * mA)) Ξ = (pinv(Σ₁) + Λ) - z = pinv(Σ₁)*pinv(mA)*my + z = pinv(Σ₁) * pinv(mA) * my return MvNormalWeightedMeanPrecision(z, Ξ) end @rule MAR(:x, Marginalisation) (q_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin - ma, Va = mean_cov(q_a) my, Vy = mean_cov(q_y) mΛ = mean(q_Λ) order, ds = getorder(meta), getdimensionality(meta) - dim = order*ds + dim = order * ds mA = mar_companion_matrix(order, ds, ma) mW = mar_transition(getorder(meta), mΛ) - + # this should be inside MARMeta es = [uvector(dim, i) for i in 1:ds] Fs = [mask_mar(order, ds, i) for i in 1:ds] - Λ = sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:ds) for j in 1:ds) - Λ₀ = Hermitian(mA'*mW*mA) + Λ = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds) + Λ₀ = Hermitian(mA' * mW * mA) Ξ = Λ₀ + Λ - z = Λ₀*pinv(mA)*my + z = Λ₀ * pinv(mA) * my return MvNormalWeightedMeanPrecision(z, Ξ) -end \ No newline at end of file +end diff --git a/src/rules/mv_autoregressive/y.jl b/src/rules/mv_autoregressive/y.jl index 2a3ffa678..b99ace9b3 100644 --- a/src/rules/mv_autoregressive/y.jl +++ b/src/rules/mv_autoregressive/y.jl @@ -1,5 +1,4 @@ -@rule MAR(:y, Marginalisation) (m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = -begin +@rule MAR(:y, Marginalisation) (m_x::MultivariateNormalDistributionsFamily, 
q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin ma, Va = mean_cov(q_a) mx, Wx = mean_invcov(m_x) @@ -9,29 +8,27 @@ begin mA = mar_companion_matrix(order, ds, ma) mW = mar_transition(getorder(meta), mΛ) - dim = order*ds + dim = order * ds # this should be inside MARMeta es = [uvector(dim, i) for i in 1:ds] Fs = [mask_mar(order, ds, i) for i in 1:ds] - - Λ = sum(sum(es[j]'*mW*es[i]*Fs[j]*Va*Fs[i]' for i in 1:ds) for j in 1:ds) + Λ = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds) Ξ = Λ + Wx - z = Wx*mx + z = Wx * mx - Vy = mA*inv(Ξ)*mA' + inv(mW) - my = mA*inv(Ξ)*z + Vy = mA * inv(Ξ) * mA' + inv(mW) + my = mA * inv(Ξ) * z return MvNormalMeanCovariance(my, Vy) end @rule MAR(:y, Marginalisation) (q_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin - order, ds = getorder(meta), getdimensionality(meta) mA = mar_companion_matrix(order, ds, mean(q_a)) mW = mar_transition(getorder(meta), mean(q_Λ)) return MvNormalMeanPrecision(mA * mean(q_x), mW) -end \ No newline at end of file +end From a490f8a6596d75213b6a8516d7990684ea413252 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Tue, 7 Mar 2023 12:41:27 +0100 Subject: [PATCH 37/48] fix tests --- src/variables/data.jl | 18 ++++--- test/variables/test_data.jl | 104 ++++++++++++++++++------------------ 2 files changed, 63 insertions(+), 59 deletions(-) diff --git a/src/variables/data.jl b/src/variables/data.jl index e5ff4bd0b..6600cda8f 100644 --- a/src/variables/data.jl +++ b/src/variables/data.jl @@ -76,7 +76,7 @@ datavar(name::Symbol, ::Type{D}, dims::Tuple) where {D} datavar(name::Symbol, ::Type{D}, dims::Vararg{Int}) where {D} = datavar(DataVariableCreationOptions(D), name, D, dims) datavar(options::DataVariableCreationOptions{S}, name::Symbol, ::Type{D}, collection_type::AbstractVariableCollectionType = VariableIndividual()) where {S, D} = - DataVariable{D, S}(name, collection_type, MarginalObservable(), Vector{MessageObservable{AbstractMessage}}(), options.subject, 0) + DataVariable{D, S}(name, collection_type, MarginalObservable(), Vector{MessageObservable{AbstractMessage}}(), options.subject, 0, options.isproxy, options.isused) function datavar(options::DataVariableCreationOptions, name::Symbol, ::Type{D}, length::Int) where {D} return map(i -> datavar(similar(options), name, D, VariableVector(i)), 1:length) @@ -122,7 +122,7 @@ function Base.getindex(datavar::DataVariable, i...) error("Variable $(indexed_name(datavar)) has been indexed with `[$(join(i, ','))]`. Direct indexing of `data` variables is not allowed.") end -getlastindex(::DataVariable) = 1 +getlastindex(datavar::DataVariable) = degree(datavar) + 1 messageout(datavar::DataVariable, ::Int) = datavar.messageout messagein(datavar::DataVariable, ::Int) = error("It is not possible to get a reference for inbound message for datavar") @@ -175,10 +175,16 @@ _makemarginal(datavar::DataVariable) = error("It is not possible to setanonymous!(::DataVariable, ::Bool) = nothing -function setmessagein!(datavar::DataVariable, ::Int, messagein) - datavar.nconnected += 1 - datavar.isused = true - push!(datavar.input_messages, messagein) +function setmessagein!(datavar::DataVariable, index::Int, messagein) + if index === (degree(datavar) + 1) + push!(datavar.input_messages, messagein) + datavar.nconnected += 1 + datavar.isused = true + else + error( + "Inconsistent state in setmessagein! function for data variable $(datavar). 
`index` should be equal to `degree(datavar) + 1 = $(degree(datavar) + 1)`, but $(index) was given instead"
+        )
+    end
 
     return nothing
 end
 
diff --git a/test/variables/test_data.jl b/test/variables/test_data.jl
index 29c2bbff4..42b517a5b 100644
--- a/test/variables/test_data.jl
+++ b/test/variables/test_data.jl
@@ -4,7 +4,7 @@ using Test
 using ReactiveMP
 using Rocket
 
-import ReactiveMP: DataVariableCreationOptions
+import ReactiveMP: DataVariableCreationOptions, MessageObservable
 import ReactiveMP: collection_type, VariableIndividual, VariableVector, VariableArray, linear_index
 import ReactiveMP: getconst, proxy_variables
 import ReactiveMP: israndom, isproxy, isused, isconnected, setmessagein!, allows_missings
@@ -44,69 +44,67 @@ import ReactiveMP: israndom, isproxy, isused, isconnected, setmessagein!, allows
         @test !israndom(variable)
         @test eltype(variable) === T
         @test name(variable) === sym
+        @test allows_missings(variable) === allow_missings
         @test collection_type(variable) isa VariableIndividual
         @test proxy_variables(variable) === nothing
         @test !isproxy(variable)
-        @test allows_missings(variable) === allow_missings
-        @test !israndom(v)
-        @test eltype(v) === type
-        @test name(v) === sym
-        @test collection_type(v) isa VariableIndividual
-        @test proxy_variables(v) === nothing
-        @test !isproxy(v)
-        @test !isused(v)
-        @test !isconnected(v)
-
-        setmessagein!(v, 1, of(nothing))
-
-        @test isused(v)
-        @test isconnected(v)
+
+        @test !isused(variable)
+        @test !isconnected(variable)
+
+        setmessagein!(variable, 1, MessageObservable())
+
+        @test isused(variable)
+        @test isconnected(variable)
+
+        # `100` could be a valid index, but messages should be initialized in order; the previous index was `1`
+        @test_throws ErrorException setmessagein!(variable, 100, MessageObservable())
     end
 
     for sym in (:x, :y, :z), T in (Float64, Int64, Vector{Float64}), n in (10, 20), allow_missings in (true, false)
        options = DataVariableCreationOptions(T, nothing, Val(allow_missings))
        variables = datavar(options, sym, T, n)

-        @test !israndom(vs)
-        @test length(vs) === n
-        @test vs isa Vector
-        @test all(v -> !israndom(v), vs)
-        @test all(v -> name(v) === sym, vs)
-        @test all(v -> collection_type(v) isa VariableVector, vs)
-        @test all(t -> linear_index(collection_type(t[2])) === t[1], enumerate(vs))
-        @test all(v -> eltype(v) === type, vs)
-        @test !isproxy(vs)
-        @test all(v -> !isproxy(v), vs)
-        @test all(v -> !isused(v), vs)
-        @test all(v -> !isconnected(v), vs)
-        @test test_updates(vs, type, (n,))
-
-        foreach(v -> setmessagein!(v, 1, of(nothing)), vs)
-
-        @test all(v -> isused(v), vs)
-        @test all(v -> isconnected(v), vs)
+        @test !israndom(variables)
+        @test length(variables) === n
+        @test variables isa Vector
+        @test all(v -> !israndom(v), variables)
+        @test all(v -> name(v) === sym, variables)
+        @test all(v -> allows_missings(v) === allow_missings, variables)
+        @test all(v -> collection_type(v) isa VariableVector, variables)
+        @test all(t -> linear_index(collection_type(t[2])) === t[1], enumerate(variables))
+        @test all(v -> eltype(v) === T, variables)
+        @test !isproxy(variables)
+        @test all(v -> !isproxy(v), variables)
+        @test all(v -> !isused(v), variables)
+        @test all(v -> !isconnected(v), variables)
+        @test test_updates(variables, T, (n,))
+
+        foreach(v -> setmessagein!(v, 1, MessageObservable()), variables)
+
+        @test all(v -> isused(v), variables)
+        @test all(v -> isconnected(v), variables)
     end
 
-    for sym in (:x, :y, :z), type in (Float64, Int64, Vector{Float64}), l in (10, 20), r in (10, 20)
-        for vs in (datavar(sym, type, l, r), datavar(sym, 
type, (l, r))) - @test !israndom(vs) - @test size(vs) === (l, r) - @test length(vs) === l * r - @test vs isa Matrix - @test all(v -> !israndom(v), vs) - @test all(v -> name(v) === sym, vs) - @test all(v -> collection_type(v) isa VariableArray, vs) - @test all(t -> linear_index(collection_type(t[2])) === t[1], enumerate(vs)) - @test all(v -> eltype(v) === type, vs) - @test !isproxy(vs) - @test all(v -> !isproxy(v), vs) - @test all(v -> !isused(v), vs) - @test test_updates(vs, type, (l, r)) - - foreach(v -> setmessagein!(v, 1, of(nothing)), vs) - - @test all(v -> isused(v), vs) - @test all(v -> isconnected(v), vs) + for sym in (:x, :y, :z), T in (Float64, Int64, Vector{Float64}), l in (10, 20), r in (10, 20) + for variables in (datavar(sym, T, l, r), datavar(sym, T, (l, r))) + @test !israndom(variables) + @test size(variables) === (l, r) + @test length(variables) === l * r + @test variables isa Matrix + @test all(v -> !israndom(v), variables) + @test all(v -> name(v) === sym, variables) + @test all(v -> collection_type(v) isa VariableArray, variables) + @test all(t -> linear_index(collection_type(t[2])) === t[1], enumerate(variables)) + @test all(v -> eltype(v) === T, variables) + @test !isproxy(variables) + @test all(v -> !isproxy(v), variables) + @test all(v -> !isused(v), variables) + @test test_updates(variables, T, (l, r)) + + foreach(v -> setmessagein!(v, 1, MessageObservable()), variables) + + @test all(v -> isused(v), variables) + @test all(v -> isconnected(v), variables) end end end From eaae1d2916a77ec1cc10b090d86e5791ce725163 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Sun, 19 Mar 2023 09:47:20 +0100 Subject: [PATCH 38/48] Update rules --- branch.diff | 1963 ++++++++++++++++++++++ src/nodes/mv_autoregressive.jl | 13 +- src/rules/mv_autoregressive/a.jl | 7 +- src/rules/mv_autoregressive/lambda.jl | 6 +- src/rules/mv_autoregressive/marginals.jl | 7 +- src/rules/mv_autoregressive/x.jl | 13 +- src/rules/mv_autoregressive/y.jl | 7 +- 7 files changed, 1987 insertions(+), 29 deletions(-) create mode 100644 branch.diff diff --git a/branch.diff b/branch.diff new file mode 100644 index 000000000..b14a433b6 --- /dev/null +++ b/branch.diff @@ -0,0 +1,1963 @@ +diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml +index 8d5d2580..93d16876 100644 +--- a/.github/workflows/ci.yml ++++ b/.github/workflows/ci.yml +@@ -1,9 +1,17 @@ + name: CI + on: + pull_request: ++ types: [ready_for_review,reopened,synchronize] ++ pull_request_review: ++ types: [submitted,edited] + push: ++ branches: ++ - 'master' ++ tags: '*' ++ check_run: ++ types: [rerequested] + schedule: +- - cron: '44 9 16 * *' # run the cron job one time per month ++ - cron: '0 8 * * 1' # run the cron job one time per week on Monday 8:00 AM + jobs: + format: + name: Julia Formatter +@@ -17,6 +25,7 @@ jobs: + test: + name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} + runs-on: ${{ matrix.os }} ++ continue-on-error: ${{ contains(matrix.version, 'nightly') }} + needs: format + strategy: + fail-fast: false +@@ -110,4 +119,4 @@ jobs: + env: + PYTHON: "" + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} +- DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} +\ No newline at end of file ++ DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} +diff --git a/Makefile b/Makefile +index 95a2851e..231f80bb 100644 +--- a/Makefile ++++ b/Makefile +@@ -30,8 +30,8 @@ docs: doc_init ## Generate documentation + + .PHONY: test + +-test: ## Run tests (use testset="folder1:test1 folder2:test2" argument to run 
reduced testset) +- julia -e 'import Pkg; Pkg.activate("."); Pkg.test(test_args = split("$(testset)") .|> string)' ++test: ## Run tests (use test_args="folder1:test1 folder2:test2" argument to run reduced testset) ++ julia -e 'import Pkg; Pkg.activate("."); Pkg.test(test_args = split("$(test_args)") .|> string)' + + help: ## Display this help + @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[a-zA-Z_-]+:.*?##/ { printf " \033[36m%-24s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) +\ No newline at end of file +diff --git a/Project.toml b/Project.toml +index d63ca068..c4754725 100644 +--- a/Project.toml ++++ b/Project.toml +@@ -1,10 +1,9 @@ + name = "ReactiveMP" + uuid = "a194aa59-28ba-4574-a09c-4a745416d6e3" + authors = ["Dmitry Bagaev ", "Albert Podusenko ", "Bart van Erp ", "Ismail Senoz "] +-version = "3.6.1" ++version = "3.7.2" + + [deps] +-BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" + DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" + Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" + DomainIntegrals = "cc6bae93-f070-4015-88fd-838f9505a86c" +@@ -43,7 +42,7 @@ MacroTools = "0.5" + Optim = "1.0.0" + PositiveFactorizations = "0.2" + Requires = "1" +-Rocket = "1.6.0" ++Rocket = "1.7.0" + SpecialFunctions = "1.4, 2" + StaticArrays = "1.2" + StatsBase = "0.33" +diff --git a/docs/src/extra/contributing.md b/docs/src/extra/contributing.md +index c9e0e317..9a84ca4c 100644 +--- a/docs/src/extra/contributing.md ++++ b/docs/src/extra/contributing.md +@@ -81,8 +81,8 @@ a new release of the broken dependecy is available. + + - `make help`: Shows help snippet + - `make test`: Run tests, supports extra arguments +- - `make test testset="distributions:normal_mean_variance"` would run tests only from `distributions/test_normal_mean_variance.jl` +- - `make test testset="distributions:normal_mean_variance models:lgssm"` would run tests both from `distributions/test_normal_mean_variance.jl` and `models/test_lgssm.jl` ++ - `make test test_args="distributions:normal_mean_variance"` would run tests only from `distributions/test_normal_mean_variance.jl` ++ - `make test test_args="distributions:normal_mean_variance models:lgssm"` would run tests both from `distributions/test_normal_mean_variance.jl` and `models/test_lgssm.jl` + - `make docs`: Compile documentation + - `make benchmark`: Run simple benchmark + - `make lint`: Check codestyle +diff --git a/src/ReactiveMP.jl b/src/ReactiveMP.jl +index 7b5c860c..fc0186b4 100644 +--- a/src/ReactiveMP.jl ++++ b/src/ReactiveMP.jl +@@ -146,7 +146,6 @@ include("nodes/matrix_dirichlet.jl") + include("nodes/dirichlet.jl") + include("nodes/bernoulli.jl") + include("nodes/gcv.jl") +-include("nodes/kernel_gcv.jl") + include("nodes/wishart.jl") + include("nodes/wishart_inverse.jl") + include("nodes/normal_mixture.jl") +@@ -154,7 +153,6 @@ include("nodes/gamma_mixture.jl") + include("nodes/dot_product.jl") + include("nodes/transition.jl") + include("nodes/autoregressive.jl") +-include("nodes/mv_autoregressive.jl") + include("nodes/bifm.jl") + include("nodes/bifm_helper.jl") + include("nodes/probit.jl") +diff --git a/src/constraints/form.jl b/src/constraints/form.jl +index a746cdd6..c6c2edc4 100644 +--- a/src/constraints/form.jl ++++ b/src/constraints/form.jl +@@ -172,7 +172,7 @@ function is_point_mass_form_constraint(composite::CompositeFormConstraint) + is_point_mass = map(is_point_mass_form_constraint, composite.constraints) + pmindex = findnext(is_point_mass, 1) + if pmindex 
!== nothing && pmindex !== length(is_point_mass) +- error("Composite form constraint supports point mass constraint only at the end of the form constrains specification.") ++ error("Composite form constraint supports point mass constraint only at the end of the form constraints specification.") + end + return last(is_point_mass) + end +diff --git a/src/distributions/bernoulli.jl b/src/distributions/bernoulli.jl +index 9757f0aa..bb30b485 100644 +--- a/src/distributions/bernoulli.jl ++++ b/src/distributions/bernoulli.jl +@@ -47,6 +47,10 @@ function prod(::ProdAnalytical, left::Bernoulli, right::Categorical) + return Categorical(ReactiveMP.normalize!(p_new, 1)) + end + ++prod_analytical_rule(::Type{<:Categorical}, ::Type{<:Bernoulli}) = ProdAnalyticalRuleAvailable() ++ ++prod(::ProdAnalytical, left::Categorical, right::Bernoulli) = prod(ProdAnalytical(), right, left) ++ + function compute_logscale(new_dist::Bernoulli, left_dist::Bernoulli, right_dist::Bernoulli) + left_p = succprob(left_dist) + right_p = succprob(right_dist) +diff --git a/src/distributions/beta.jl b/src/distributions/beta.jl +index 9541c90a..129c4f7f 100644 +--- a/src/distributions/beta.jl ++++ b/src/distributions/beta.jl +@@ -1,7 +1,9 @@ + export Beta ++export BetaNaturalParameters + + import Distributions: Beta, params +-import SpecialFunctions: digamma, logbeta ++import SpecialFunctions: digamma, logbeta, loggamma ++import StatsFuns: betalogpdf + + vague(::Type{<:Beta}) = Beta(1.0, 1.0) + +@@ -27,3 +29,47 @@ function mean(::typeof(mirrorlog), dist::Beta) + a, b = params(dist) + return digamma(b) - digamma(a + b) + end ++ ++struct BetaNaturalParameters{T <: Real} <: NaturalParameters ++ αm1::T ++ βm1::T ++end ++ ++BetaNaturalParameters(αm1::Real, βm1::Real) = BetaNaturalParameters(promote(αm1, βm1)...) ++BetaNaturalParameters(αm1::Integer, βm1::Integer) = BetaNaturalParameters(float(αm1), float(βm1)) ++ ++Base.convert(::Type{BetaNaturalParameters}, a::Real, b::Real) = convert(BetaNaturalParameters{promote_type(typeof(a), typeof(b))}, a, b) ++ ++Base.convert(::Type{BetaNaturalParameters{T}}, a::Real, b::Real) where {T} = BetaNaturalParameters(convert(T, a), convert(T, b)) ++ ++Base.convert(::Type{BetaNaturalParameters}, vec::AbstractVector) = convert(BetaNaturalParameters{eltype(vec)}, vec) ++ ++Base.convert(::Type{BetaNaturalParameters{T}}, vec::AbstractVector) where {T} = BetaNaturalParameters(convert(AbstractVector{T}, vec)) ++ ++function isproper(params::BetaNaturalParameters) ++ return ((params.αm1 + 1) > 0) && ((params.βm1 + 1) > 0) ++end ++ ++naturalparams(dist::Beta) = BetaNaturalParameters(dist.α - 1, dist.β - 1) ++ ++function Base.convert(::Type{Distribution}, η::BetaNaturalParameters) ++ return Beta(η.αm1 + 1, η.βm1 + 1, check_args = false) ++end ++ ++function Base.vec(p::BetaNaturalParameters) ++ return [p.αm1, p.βm1] ++end ++ ++ReactiveMP.as_naturalparams(::Type{T}, args...) where {T <: BetaNaturalParameters} = convert(BetaNaturalParameters, args...) ++ ++function BetaNaturalParameters(v::AbstractVector{T}) where {T <: Real} ++ @assert length(v) === 2 "`BetaNaturalParameters` must accept a vector of length `2`." 
++ return BetaNaturalParameters(v[1], v[2]) ++end ++ ++lognormalizer(params::BetaNaturalParameters) = logbeta(params.αm1 + 1, params.βm1 + 1) ++logpdf(params::BetaNaturalParameters, x) = betalogpdf(params.αm1 + 1, params.βm1 + 1, x) ++ ++function Base.:-(left::BetaNaturalParameters, right::BetaNaturalParameters) ++ return BetaNaturalParameters(left.αm1 - right.αm1, left.βm1 - right.βm1) ++end +diff --git a/src/distributions/mv_normal_mean_covariance.jl b/src/distributions/mv_normal_mean_covariance.jl +index 7938bf16..079016ca 100644 +--- a/src/distributions/mv_normal_mean_covariance.jl ++++ b/src/distributions/mv_normal_mean_covariance.jl +@@ -26,6 +26,13 @@ function MvNormalMeanCovariance(μ::AbstractVector{T}) where {T} + return MvNormalMeanCovariance(μ, convert(AbstractArray{T}, ones(length(μ)))) + end + ++function MvNormalMeanCovariance(μ::AbstractVector{T1}, Σ::UniformScaling{T2}) where {T1, T2} ++ T = promote_type(T1, T2) ++ μ_new = convert(AbstractArray{T}, μ) ++ Σ_new = convert(UniformScaling{T}, Σ)(length(μ)) ++ return MvNormalMeanCovariance(μ_new, Σ_new) ++end ++ + Distributions.distrname(::MvNormalMeanCovariance) = "MvNormalMeanCovariance" + + function weightedmean(dist::MvNormalMeanCovariance) +@@ -88,7 +95,9 @@ function Base.prod(::ProdAnalytical, left::MvNormalMeanCovariance, right::MvNorm + return MvNormalWeightedMeanPrecision(xi_left + xi_right, W_left + W_right) + end + +-function Base.prod(::ProdAnalytical, left::MvNormalMeanCovariance{T1}, right::MvNormalMeanCovariance{T2}) where {T1 <: LinearAlgebra.BlasFloat, T2 <: LinearAlgebra.BlasFloat} ++function Base.prod( ++ ::ProdAnalytical, left::MvNormalMeanCovariance{T1, <:AbstractVector, <:Matrix}, right::MvNormalMeanCovariance{T2, <:AbstractVector, <:Matrix} ++) where {T1 <: LinearAlgebra.BlasFloat, T2 <: LinearAlgebra.BlasFloat} + + # start with parameters of left + xi, W = weightedmean_precision(left) +diff --git a/src/distributions/mv_normal_mean_precision.jl b/src/distributions/mv_normal_mean_precision.jl +index 889b0014..815eb424 100644 +--- a/src/distributions/mv_normal_mean_precision.jl ++++ b/src/distributions/mv_normal_mean_precision.jl +@@ -26,6 +26,13 @@ function MvNormalMeanPrecision(μ::AbstractVector{T}) where {T} + return MvNormalMeanPrecision(μ, convert(AbstractArray{T}, ones(length(μ)))) + end + ++function MvNormalMeanPrecision(μ::AbstractVector{T1}, Λ::UniformScaling{T2}) where {T1, T2} ++ T = promote_type(T1, T2) ++ μ_new = convert(AbstractArray{T}, μ) ++ Λ_new = convert(UniformScaling{T}, Λ)(length(μ)) ++ return MvNormalMeanPrecision(μ_new, Λ_new) ++end ++ + Distributions.distrname(::MvNormalMeanPrecision) = "MvNormalMeanPrecision" + + weightedmean(dist::MvNormalMeanPrecision) = precision(dist) * mean(dist) +@@ -92,7 +99,9 @@ function Base.prod(::ProdAnalytical, left::MvNormalMeanPrecision, right::MvNorma + return MvNormalWeightedMeanPrecision(xi, W) + end + +-function Base.prod(::ProdAnalytical, left::MvNormalMeanPrecision{T1}, right::MvNormalMeanPrecision{T2}) where {T1 <: LinearAlgebra.BlasFloat, T2 <: LinearAlgebra.BlasFloat} ++function Base.prod( ++ ::ProdAnalytical, left::MvNormalMeanPrecision{T1, <:AbstractVector, <:Matrix}, right::MvNormalMeanPrecision{T2, <:AbstractVector, <:Matrix} ++) where {T1 <: LinearAlgebra.BlasFloat, T2 <: LinearAlgebra.BlasFloat} + W = precision(left) + precision(right) + + # fast & efficient implementation of xi = precision(right)*mean(right) + precision(left)*mean(left) +diff --git a/src/distributions/mv_normal_weighted_mean_precision.jl 
b/src/distributions/mv_normal_weighted_mean_precision.jl +index 86aa046e..7137ccb4 100644 +--- a/src/distributions/mv_normal_weighted_mean_precision.jl ++++ b/src/distributions/mv_normal_weighted_mean_precision.jl +@@ -26,6 +26,13 @@ function MvNormalWeightedMeanPrecision(xi::AbstractVector{T}) where {T} + return MvNormalWeightedMeanPrecision(xi, convert(AbstractArray{T}, ones(length(xi)))) + end + ++function MvNormalWeightedMeanPrecision(xi::AbstractVector{T1}, Λ::UniformScaling{T2}) where {T1, T2} ++ T = promote_type(T1, T2) ++ xi_new = convert(AbstractArray{T}, xi) ++ Λ_new = convert(UniformScaling{T}, Λ)(length(xi)) ++ return MvNormalWeightedMeanPrecision(xi_new, Λ_new) ++end ++ + Distributions.distrname(::MvNormalWeightedMeanPrecision) = "MvNormalWeightedMeanPrecision" + + weightedmean(dist::MvNormalWeightedMeanPrecision) = dist.xi +diff --git a/src/distributions/normal_mean_precision.jl b/src/distributions/normal_mean_precision.jl +index d6b9ef8a..b5083756 100644 +--- a/src/distributions/normal_mean_precision.jl ++++ b/src/distributions/normal_mean_precision.jl +@@ -11,6 +11,12 @@ NormalMeanPrecision(μ::Real, w::Real) = NormalMeanPrecision(promote(μ, w + NormalMeanPrecision(μ::Integer, w::Integer) = NormalMeanPrecision(float(μ), float(w)) + NormalMeanPrecision(μ::Real) = NormalMeanPrecision(μ, one(μ)) + NormalMeanPrecision() = NormalMeanPrecision(0.0, 1.0) ++function NormalMeanPrecision(μ::T1, w::UniformScaling{T2}) where {T1 <: Real, T2} ++ T = promote_type(T1, T2) ++ μ_new = convert(T, μ) ++ w_new = convert(T, w.λ) ++ return NormalMeanPrecision(μ_new, w_new) ++end + + Distributions.@distr_support NormalMeanPrecision -Inf Inf + +diff --git a/src/distributions/normal_mean_variance.jl b/src/distributions/normal_mean_variance.jl +index 759f319d..2b1770b9 100644 +--- a/src/distributions/normal_mean_variance.jl ++++ b/src/distributions/normal_mean_variance.jl +@@ -11,6 +11,12 @@ NormalMeanVariance(μ::Real, v::Real) = NormalMeanVariance(promote(μ, v). 
+ NormalMeanVariance(μ::Integer, v::Integer) = NormalMeanVariance(float(μ), float(v)) + NormalMeanVariance(μ::T) where {T <: Real} = NormalMeanVariance(μ, one(T)) + NormalMeanVariance() = NormalMeanVariance(0.0, 1.0) ++function NormalMeanVariance(μ::T1, v::UniformScaling{T2}) where {T1 <: Real, T2} ++ T = promote_type(T1, T2) ++ μ_new = convert(T, μ) ++ v_new = convert(T, v.λ) ++ return NormalMeanVariance(μ_new, v_new) ++end + + Distributions.@distr_support NormalMeanVariance -Inf Inf + +diff --git a/src/distributions/normal_weighted_mean_precision.jl b/src/distributions/normal_weighted_mean_precision.jl +index f68a6abe..c17f9fb3 100644 +--- a/src/distributions/normal_weighted_mean_precision.jl ++++ b/src/distributions/normal_weighted_mean_precision.jl +@@ -11,6 +11,12 @@ NormalWeightedMeanPrecision(xi::Real, w::Real) = NormalWeightedMeanPrecisi + NormalWeightedMeanPrecision(xi::Integer, w::Integer) = NormalWeightedMeanPrecision(float(xi), float(w)) + NormalWeightedMeanPrecision(xi::Real) = NormalWeightedMeanPrecision(xi, one(xi)) + NormalWeightedMeanPrecision() = NormalWeightedMeanPrecision(0.0, 1.0) ++function NormalWeightedMeanPrecision(xi::T1, w::UniformScaling{T2}) where {T1 <: Real, T2} ++ T = promote_type(T1, T2) ++ xi_new = convert(T, xi) ++ w_new = convert(T, w.λ) ++ return NormalWeightedMeanPrecision(xi_new, w_new) ++end + + Distributions.@distr_support NormalWeightedMeanPrecision -Inf Inf + +diff --git a/src/distributions/pointmass.jl b/src/distributions/pointmass.jl +index 38c15c24..c4612623 100644 +--- a/src/distributions/pointmass.jl ++++ b/src/distributions/pointmass.jl +@@ -1,5 +1,7 @@ + export PointMass, getpointmass + ++using LinearAlgebra: UniformScaling, I ++ + import Distributions: mean, var, cov, std, insupport, pdf, logpdf, entropy + import Base: ndims, precision, getindex, size, convert, isapprox, eltype + import SpecialFunctions: loggamma, logbeta +@@ -13,12 +15,14 @@ end + variate_form(::PointMass{T}) where {T <: Real} = Univariate + variate_form(::PointMass{V}) where {T, V <: AbstractVector{T}} = Multivariate + variate_form(::PointMass{M}) where {T, M <: AbstractMatrix{T}} = Matrixvariate ++variate_form(::PointMass{U}) where {T, U <: UniformScaling{T}} = Matrixvariate + + ## + + sampletype(distribution::PointMass{T}) where {T} = T + + getpointmass(distribution::PointMass) = distribution.point ++getpointmass(point::Union{Real, AbstractArray}) = point + + ## + +@@ -111,6 +115,31 @@ convert_eltype(::Type{PointMass}, ::Type{T}, distribution::PointMass{R}) where { + + Base.eltype(::PointMass{M}) where {T <: Real, M <: AbstractMatrix{T}} = T + ++# UniformScaling-based matrixvariate point mass ++ ++Distributions.insupport(distribution::PointMass{M}, x::UniformScaling) where {T <: Real, M <: UniformScaling{T}} = x == getpointmass(distribution) ++Distributions.pdf(distribution::PointMass{M}, x::UniformScaling) where {T <: Real, M <: UniformScaling{T}} = Distributions.insupport(distribution, x) ? one(T) : zero(T) ++Distributions.logpdf(distribution::PointMass{M}, x::UniformScaling) where {T <: Real, M <: UniformScaling{T}} = Distributions.insupport(distribution, x) ? 
zero(T) : convert(T, -Inf) ++ ++Distributions.mean(distribution::PointMass{M}) where {T <: Real, M <: UniformScaling{T}} = getpointmass(distribution) ++Distributions.mode(distribution::PointMass{M}) where {T <: Real, M <: UniformScaling{T}} = mean(distribution) ++Distributions.var(distribution::PointMass{M}) where {T <: Real, M <: UniformScaling{T}} = zero(T) * I ++Distributions.std(distribution::PointMass{M}) where {T <: Real, M <: UniformScaling{T}} = zero(T) * I ++Distributions.cov(distribution::PointMass{M}) where {T <: Real, M <: UniformScaling{T}} = error("Distributions.cov(::PointMass{ <: UniformScaling }) is not defined") ++ ++probvec(distribution::PointMass{M}) where {T <: Real, M <: UniformScaling{T}} = error("probvec(::PointMass{ <: UniformScaling }) is not defined") ++ ++mean(::typeof(inv), distribution::PointMass{M}) where {T <: Real, M <: UniformScaling{T}} = inv(mean(distribution)) ++mean(::typeof(cholinv), distribution::PointMass{M}) where {T <: Real, M <: UniformScaling{T}} = inv(mean(distribution)) ++ ++Base.precision(distribution::PointMass{M}) where {T <: Real, M <: UniformScaling{T}} = one(T) ./ cov(distribution) ++Base.ndims(distribution::PointMass{M}) where {T <: Real, M <: UniformScaling{T}} = size(mean(distribution)) ++ ++convert_eltype(::Type{PointMass}, ::Type{T}, distribution::PointMass{R}) where {T <: Real, R <: UniformScaling} = PointMass(convert(AbstractMatrix{T}, getpointmass(distribution))) ++convert_eltype(::Type{PointMass}, ::Type{T}, distribution::PointMass{R}) where {T <: AbstractMatrix, R <: UniformScaling} = PointMass(convert(T, getpointmass(distribution))) ++ ++Base.eltype(::PointMass{M}) where {T <: Real, M <: UniformScaling{T}} = T ++ + Base.isapprox(left::PointMass, right::PointMass; kwargs...) = Base.isapprox(getpointmass(left), getpointmass(right); kwargs...) + Base.isapprox(left::PointMass, right; kwargs...) = false + Base.isapprox(left, right::PointMass; kwargs...) 
= false +diff --git a/src/marginal.jl b/src/marginal.jl +index 72cd8819..75a39143 100644 +--- a/src/marginal.jl ++++ b/src/marginal.jl +@@ -109,6 +109,11 @@ struct SkipInitial <: MarginalSkipStrategy end + struct SkipClampedAndInitial <: MarginalSkipStrategy end + struct IncludeAll <: MarginalSkipStrategy end + ++Base.broadcastable(::SkipClamped) = Ref(SkipClamped()) ++Base.broadcastable(::SkipInitial) = Ref(SkipInitial()) ++Base.broadcastable(::SkipClampedAndInitial) = Ref(SkipClampedAndInitial()) ++Base.broadcastable(::IncludeAll) = Ref(IncludeAll()) ++ + apply_skip_filter(observable, ::SkipClamped) = observable |> filter(v -> !is_clamped(v)) + apply_skip_filter(observable, ::SkipInitial) = observable |> filter(v -> !is_initial(v)) + apply_skip_filter(observable, ::SkipClampedAndInitial) = observable |> filter(v -> !is_initial(v) && !is_clamped(v)) +diff --git a/src/node.jl b/src/node.jl +index 56b5e5e1..605d0c78 100644 +--- a/src/node.jl ++++ b/src/node.jl +@@ -866,7 +866,7 @@ function activate!(factornode::AbstractFactorNode, options) + vmessageout = combineLatest((msgs_observable, marginals_observable), PushNew()) # TODO check PushEach + vmessageout = apply_pipeline_stage(get_pipeline_stages(interface), factornode, vtag, vmessageout) + +- mapping = let messagemap = MessageMapping(fform, vtag, vconstraint, msgs_names, marginal_names, meta, addons, factornode) ++ mapping = let messagemap = MessageMapping(fform, vtag, vconstraint, msgs_names, marginal_names, meta, addons, node_if_required(fform, factornode)) + (dependencies) -> VariationalMessage(dependencies[1], dependencies[2], messagemap) + end + +@@ -939,11 +939,11 @@ function getmarginal!(factornode::FactorNode, localmarginal::FactorNodeLocalMarg + vtag = Val{name(localmarginal)} + meta = metadata(factornode) + +- mapping = MarginalMapping(fform, vtag, msgs_names, marginal_names, meta, factornode) ++ mapping = MarginalMapping(fform, vtag, msgs_names, marginal_names, meta, node_if_required(fform, factornode)) + # TODO: discontinue operator is needed for loopy belief propagation? Check + marginalout = combineLatest((msgs_observable, marginals_observable), PushNew()) |> discontinue() |> map(Marginal, mapping) + +- connect!(cmarginal, marginalout) # MarginalObservable has RecentSubject by default, there is no need to share_recent() here ++ connect!(cmarginal, marginalout) + + return apply_skip_filter(cmarginal, skip_strategy) + end +@@ -955,7 +955,7 @@ end + make_node(node) + make_node(node, options) + +-Creates a factor node of a given type and options. See the list of avaialble factor nodes below. ++Creates a factor node of a given type and options. See the list of available factor nodes below. 
+ + See also: [`@node`](@ref) + +diff --git a/src/nodes/autoregressive.jl b/src/nodes/autoregressive.jl +index 4e80c275..07a620fe 100644 +--- a/src/nodes/autoregressive.jl ++++ b/src/nodes/autoregressive.jl +@@ -1,4 +1,4 @@ +-export AR, Autoregressive, ARsafe, ARunsafe, ARMeta, ar_unit, ar_slice ++export AR, Autoregressive, ARsafe, ARunsafe, ARMeta + + import LazyArrays + import Distributions: VariateForm +diff --git a/src/nodes/delta/delta.jl b/src/nodes/delta/delta.jl +index d302ee7c..cfe34369 100644 +--- a/src/nodes/delta/delta.jl ++++ b/src/nodes/delta/delta.jl +@@ -79,8 +79,11 @@ end + # For missing rules error msg + rule_method_error_extract_fform(f::Type{<:DeltaFn}) = "DeltaFn{f}" + ++# `DeltaFn` requires an access to the node function, hence, node reference is required ++call_rule_is_node_required(::Type{<:DeltaFn}) = CallRuleNodeRequired() ++ + # For `@call_rule` and `@call_marginalrule` +-function call_rule_make_node(::UndefinedNodeFunctionalForm, fformtype::Type{<:DeltaFn}, nodetype::F, meta::DeltaMeta) where {F} ++function call_rule_make_node(::CallRuleNodeRequired, fformtype::Type{<:DeltaFn}, nodetype::F, meta::DeltaMeta) where {F} + # This node is not initialized properly, but we do not expect rules to access internal uninitialized fields. + # Doing so will most likely throw an error + return DeltaFnNode(nodetype, NodeInterface(:out, Marginalisation()), (), nothing, collect_meta(DeltaFn{F}, meta)) +diff --git a/src/nodes/kernel_gcv.jl b/src/nodes/kernel_gcv.jl +deleted file mode 100644 +index 8434eacc..00000000 +--- a/src/nodes/kernel_gcv.jl ++++ /dev/null +@@ -1,34 +0,0 @@ +-export KernelGCV, KernelGCVMetadata +- +-import LinearAlgebra: logdet, tr +- +-struct KernelGCVMetadata{F, A} +- kernelFn :: F +- approximation :: A +-end +- +-get_kernelfn(meta::KernelGCVMetadata) = meta.kernelFn +-get_approximation(meta::KernelGCVMetadata) = meta.approximation +- +-struct KernelGCV end +- +-@node KernelGCV Stochastic [y, x, z] +- +-# TODO: Remove in favor of Generic Functional Message +-struct FnWithApproximation{F, A} +- fn :: F +- approximation :: A +-end +- +-prod_analytical_rule(::Type{<:MultivariateNormalDistributionsFamily}, ::Type{<:FnWithApproximation}) = ProdAnalyticalRuleAvailable() +- +-function prod(::ProdAnalytical, left::MultivariateNormalDistributionsFamily, right::FnWithApproximation) +- μ, Σ = approximate_meancov(right.approximation, (s) -> exp(right.fn(s)), left) +- return MvNormalMeanCovariance(μ, Σ) +-end +- +-prod_analytical_rule(::Type{<:FnWithApproximation}, ::Type{<:MultivariateNormalDistributionsFamily}) = ProdAnalyticalRuleAvailable() +- +-function prod(::ProdAnalytical, left::FnWithApproximation, right::MultivariateNormalDistributionsFamily) +- return prod(ProdAnalytical(), right, left) +-end +diff --git a/src/nodes/mv_autoregressive.jl b/src/nodes/mv_autoregressive.jl +deleted file mode 100644 +index 1cc739b5..00000000 +--- a/src/nodes/mv_autoregressive.jl ++++ /dev/null +@@ -1,156 +0,0 @@ +-export MAR, MvAutoregressive, MARMeta, mar_transition, mar_shift +- +-import LazyArrays, BlockArrays +-import StatsFuns: log2π +- +-struct MAR end +- +-const MvAutoregressive = MAR +- +-struct MARMeta +- order :: Int # order (lag) of MAR +- ds :: Int # dimensionality of MAR process, i.e., the number of correlated AR processes +- +- function MARMeta(order, ds = 2) +- if ds < 2 +- @error "ds parameter should be > 1. 
Use AR node if ds = 1" +- end +- return new(order, ds) +- end +-end +- +-getorder(meta::MARMeta) = meta.order +-getdimensionality(meta::MARMeta) = meta.ds +- +-@node MAR Stochastic [y, x, a, Λ] +- +-default_meta(::Type{MAR}) = error("MvAutoregressive node requires meta flag explicitly specified") +- +-@average_energy MAR (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Wishart, meta::MARMeta) = begin +- ma, Va = mean_cov(q_a) +- myx, Vyx = mean_cov(q_y_x) +- mΛ = mean(q_Λ) +- +- order, ds = getorder(meta), getdimensionality(meta) +- F = Multivariate +- dim = order * ds +- n = div(ndims(q_y_x), 2) +- +- ma, Va = mean_cov(q_a) +- mA = mar_companion_matrix(order, ds, ma)[1:ds, 1:dim] +- +- mx, Vx = ar_slice(F, myx, (dim + 1):(2dim)), ar_slice(F, Vyx, (dim + 1):(2dim), (dim + 1):(2dim)) +- my1, Vy1 = myx[1:ds], Vyx[1:ds, 1:ds] +- Vy1x = ar_slice(F, Vyx, 1:ds, (dim + 1):(2dim)) +- +- # @show Vyx +- # @show Vy1x +- +- # this should be inside MARMeta +- es = [uvector(ds, i) for i in 1:ds] +- Fs = [mask_mar(order, ds, i) for i in 1:ds] +- +- g₁ = my1' * mΛ * my1 + tr(Vy1 * mΛ) +- g₂ = mx' * mA' * mΛ * my1 + tr(Vy1x * mA' * mΛ) +- g₃ = g₂ +- G = sum(sum(es[i]' * mΛ * es[j] * Fs[i] * (ma * ma' + Va) * Fs[j]' for i in 1:ds) for j in 1:ds) +- g₄ = mx' * G * mx + tr(Vx * G) +- AE = n / 2 * log2π - 0.5 * mean(logdet, q_Λ) + 0.5 * (g₁ - g₂ - g₃ + g₄) +- +- if order > 1 +- AE += entropy(q_y_x) +- idc = LazyArrays.Vcat(1:ds, (dim + 1):(2dim)) +- myx_n = view(myx, idc) +- Vyx_n = view(Vyx, idc, idc) +- q_y_x = MvNormalMeanCovariance(myx_n, Vyx_n) +- AE -= entropy(q_y_x) +- end +- +- return AE +-end +- +-@average_energy MAR ( +- q_y::MultivariateNormalDistributionsFamily, q_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Wishart, meta::MARMeta +-) = begin +- ma, Va = mean_cov(q_a) +- my, Vy = mean_cov(q_y) +- mx, Vx = mean_cov(q_y) +- mΛ = mean(q_Λ) +- +- order, ds = getorder(meta), getdimensionality(meta) +- F = Multivariate +- dim = order * ds +- n = dim +- +- ma, Va = mean_cov(q_a) +- mA = mar_companion_matrix(order, ds, ma)[1:ds, 1:dim] +- +- my1, Vy1 = my[1:ds], Vy[1:ds, 1:ds] +- +- # this should be inside MARMeta +- es = [uvector(ds, i) for i in 1:ds] +- Fs = [mask_mar(order, ds, i) for i in 1:ds] +- +- g₁ = my1' * mΛ * my1 + tr(Vy1 * mΛ) +- g₂ = -mx' * mA' * mΛ * my1 +- g₃ = -g₂ +- G = sum(sum(es[i]' * mΛ * es[j] * Fs[i] * (ma * ma' + Va) * Fs[j]' for i in 1:ds) for j in 1:ds) +- g₄ = mx' * G * mx + tr(Vx * G) +- AE = n / 2 * log2π - 0.5 * mean(logdet, q_Λ) + 0.5 * (g₁ + g₂ + g₃ + g₄) +- +- if order > 1 +- AE += entropy(q_y) +- q_y = MvNormalMeanCovariance(my1, Vy1) +- AE -= entropy(q_y) +- end +- +- return AE +-end +- +-# Helpers for AR rules +-function mask_mar(order, dimension, index) +- F = zeros(dimension * order, dimension * dimension * order) +- rows = repeat([dimension], order) +- cols = repeat([dimension], dimension * order) +- FB = BlockArrays.BlockArray(F, rows, cols) +- for k in 1:order +- for j in 1:(dimension * order) +- if j == index + (k - 1) * dimension +- view(FB, BlockArrays.Block(k, j)) .= diageye(dimension) +- end +- end +- end +- return Matrix(FB) +-end +- +-function mar_transition(order, Λ) +- dim = size(Λ, 1) +- W = 1.0 * diageye(dim * order) +- W[1:dim, 1:dim] = Λ +- return W +-end +- +-function mar_shift(order, ds) +- dim = order * ds +- S = diageye(dim) +- for i in dim:-1:(ds + 1) +- S[i, :] = S[i - ds, :] +- end +- S[1:ds, :] = zeros(ds, dim) +- return S +-end +- +-function 
uvector(dim, pos = 1) +- u = zeros(dim) +- u[pos] = 1 +- return dim == 1 ? u[pos] : u +-end +- +-function mar_companion_matrix(order, ds, a) +- dim = order * ds +- S = mar_shift(order, ds) +- es = [uvector(dim, i) for i in 1:ds] +- Fs = [mask_mar(order, ds, i) for i in 1:ds] +- L = S .+ sum(es[i] * a' * Fs[i]' for i in 1:ds) +- return L +-end +diff --git a/src/rule.jl b/src/rule.jl +index 4cef52ef..4447e166 100644 +--- a/src/rule.jl ++++ b/src/rule.jl +@@ -202,21 +202,48 @@ function call_rule_macro_parse_fn_args(inputs; specname, prefix, proxy) + return names_arg, values_arg + end + ++# This trait indicates that a node reference is required for a proper rule execution ++# Most of the message passing update rules do not require a node reference ++# An example of a rule that requires a node is the `delta`, that needs the node function ++struct CallRuleNodeRequired end ++ ++# This trait indicates that a node reference is not required for a proper rule execution ++# This is used by default ++struct CallRuleNodeNotRequired end ++ ++""" ++ call_rule_is_node_required(fformtype) ++ ++Returns either `CallRuleNodeRequired()` or `CallRuleNodeNotRequired()` depending on if a specific ++`fformtype` requires an access to the corresponding node in order to compute a message update rule. ++Returns `CallRuleNodeNotRequired()` for all known functional forms by default and `CallRuleNodeRequired()` for all unknown functional forms. ++""" ++call_rule_is_node_required(fformtype) = call_rule_is_node_required(as_node_functional_form(fformtype), fformtype) ++ ++call_rule_is_node_required(::ValidNodeFunctionalForm, fformtype) = CallRuleNodeNotRequired() ++call_rule_is_node_required(::UndefinedNodeFunctionalForm, fformtype) = CallRuleNodeRequired() ++ ++# Returns the `node` if it is required for a rule, otherwise returns `nothing` ++node_if_required(fformtype, node) = node_if_required(call_rule_is_node_required(fformtype), node) ++ ++node_if_required(::CallRuleNodeRequired, node) = node ++node_if_required(::CallRuleNodeNotRequired, node) = nothing ++ + """ + call_rule_create_node(::Type{ NodeType }, fformtype) + +-Creates a node object that will be used inside `@call_rule` macro. The node object always creates with the default options for factorisation. ++Creates a node object that will be used inside `@call_rule` macro. + """ + function call_rule_make_node(fformtype, nodetype, meta) +- return call_rule_make_node(ReactiveMP.as_node_functional_form(nodetype), fformtype, nodetype, meta) ++ return call_rule_make_node(call_rule_is_node_required(nodetype), fformtype, nodetype, meta) + end + +-function call_rule_make_node(::UndefinedNodeFunctionalForm, fformtype, nodetype, meta) +- return error("Cannot create a node of type `$nodetype` for the call rule routine.") ++function call_rule_make_node(::CallRuleNodeRequired, fformtype, nodetype, meta) ++ return error("Missing implementation for the `call_rule_make_node`. 
Cannot create a node of type `$nodetype` for the call rule routine.") + end + +-function call_rule_make_node(::ValidNodeFunctionalForm, fformtype, nodetype, meta) +- return make_node(nodetype, FactorNodeCreationOptions(nothing, meta, nothing)) ++function call_rule_make_node(::CallRuleNodeNotRequired, fformtype, nodetype, meta) ++ return nothing + end + + call_rule_macro_construct_on_arg(on_type, on_index::Nothing) = MacroHelpers.bottom_type(on_type) +diff --git a/src/rules/bernoulli/marginals.jl b/src/rules/bernoulli/marginals.jl +index 78146875..40a68b2b 100644 +--- a/src/rules/bernoulli/marginals.jl ++++ b/src/rules/bernoulli/marginals.jl +@@ -5,3 +5,7 @@ export marginalrule + p = prod(ProdAnalytical(), Beta(one(r) + r, 2one(r) - r), m_p) + return (out = m_out, p = p) + end ++ ++@marginalrule Bernoulli(:out_p) (m_out::Bernoulli, m_p::PointMass) = begin ++ return (out = prod(ProdAnalytical(), Bernoulli(mean(m_p)), m_out), p = m_p) ++end +diff --git a/src/rules/categorical/marginals.jl b/src/rules/categorical/marginals.jl +index 4befdbb1..88c8890f 100644 +--- a/src/rules/categorical/marginals.jl ++++ b/src/rules/categorical/marginals.jl +@@ -2,3 +2,8 @@ + @marginalrule Categorical(:out_p) (m_out::Categorical, m_p::PointMass) = begin + return (out = prod(ProdAnalytical(), Categorical(mean(m_p)), m_out), p = m_p) + end ++ ++@marginalrule Categorical(:out_p) (m_out::PointMass, m_p::Dirichlet) = begin ++ p = prod(ProdAnalytical(), Dirichlet(probvec(m_out) .+ one(eltype(probvec(m_out)))), m_p) ++ return (out = m_out, p = p) ++end +diff --git a/src/rules/kernel_gcv/marginals.jl b/src/rules/kernel_gcv/marginals.jl +deleted file mode 100644 +index 9745498f..00000000 +--- a/src/rules/kernel_gcv/marginals.jl ++++ /dev/null +@@ -1,33 +0,0 @@ +-export marginalrule +- +-@marginalrule KernelGCV(:y_x) (m_y::MvNormalMeanCovariance, m_x::MvNormalMeanCovariance, q_z::MvNormalMeanCovariance, meta::KernelGCVMetadata) = begin +- kernelfunction = get_kernelfn(meta) +- Λ = approximate_kernel_expectation(get_approximation(meta), (z) -> cholinv(kernelfunction(z)), q_z) +- +- Λy = invcov(m_y) +- Λx = invcov(m_x) +- +- wy = Λy * mean(m_y) +- wx = Λx * mean(m_x) +- +- C = cholinv([Λ+Λy -Λ; -Λ Λ+Λx]) +- m = C * [wy; wx] +- +- return MvNormalMeanCovariance(m, C) +-end +- +-@marginalrule KernelGCV(:y_x) (m_y::MvNormalMeanPrecision, m_x::MvNormalMeanPrecision, q_z::MvNormalMeanPrecision, meta::KernelGCVMetadata) = begin +- kernelfunction = get_kernelfn(meta) +- C = approximate_kernel_expectation(get_approximation(meta), (z) -> cholinv(kernelfunction(z)), q_z) +- +- Cy = invcov(m_y) +- Cx = invcov(m_x) +- +- wy = Cy * mean(m_y) +- wx = Cx * mean(m_x) +- +- Λ = [C+Cy -C; -C C+Cx] +- μ = cholinv(Λ) * [wy; wx] +- +- return MvNormalMeanPrecision(μ, Λ) +-end +diff --git a/src/rules/kernel_gcv/x.jl b/src/rules/kernel_gcv/x.jl +deleted file mode 100644 +index 29776006..00000000 +--- a/src/rules/kernel_gcv/x.jl ++++ /dev/null +@@ -1,13 +0,0 @@ +-export rule +- +-@rule KernelGCV(:x, Marginalisation) (m_y::MvNormalMeanCovariance, q_z::MvNormalMeanCovariance, meta::KernelGCVMetadata) = begin +- kernelfunction = get_kernelfn(meta) +- Λ_out = approximate_kernel_expectation(get_approximation(meta), (s) -> cholinv(kernelfunction(s)), q_z) +- return MvNormalMeanCovariance(mean(m_y), cov(m_y) + cholinv(Λ_out)) +-end +- +-@rule KernelGCV(:x, Marginalisation) (m_y::MvNormalMeanPrecision, q_z::MvNormalMeanPrecision, meta::KernelGCVMetadata) = begin +- kernelfunction = get_kernelfn(meta) +- Λ_out = 
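
Note: the `CallRuleNodeRequired` / `CallRuleNodeNotRequired` pair above is an ordinary Julia trait. A minimal self-contained sketch of the dispatch flow, with hypothetical `ToyForm`, `required_for` and `node_if_needed` names standing in for the package internals:

struct RequiredSketch end       # stands in for CallRuleNodeRequired
struct NotRequiredSketch end    # stands in for CallRuleNodeNotRequired

struct ToyForm end                            # hypothetical functional form
required_for(::Type{ToyForm}) = NotRequiredSketch()

node_if_needed(T) = node_if_needed(required_for(T), T)
node_if_needed(::NotRequiredSketch, T) = nothing                    # rule runs node-free
node_if_needed(::RequiredSketch, T)    = error("construct a node for $T")

@assert node_if_needed(ToyForm) === nothing
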
approximate_kernel_expectation(get_approximation(meta), (s) -> cholinv(kernelfunction(s)), q_z) +- return MvNormalMeanPrecision(mean(m_y), cholinv(cov(m_y) + cholinv(Λ_out))) +-end +diff --git a/src/rules/kernel_gcv/y.jl b/src/rules/kernel_gcv/y.jl +deleted file mode 100644 +index a6a4b82f..00000000 +--- a/src/rules/kernel_gcv/y.jl ++++ /dev/null +@@ -1,13 +0,0 @@ +-export rule +- +-@rule KernelGCV(:y, Marginalisation) (m_x::MvNormalMeanCovariance, q_z::MvNormalMeanCovariance, meta::KernelGCVMetadata) = begin +- kernelfunction = get_kernelfn(meta) +- Λ_out = approximate_kernel_expectation(get_approximation(meta), (s) -> cholinv(kernelfunction(s)), q_z) +- return MvNormalMeanCovariance(mean(m_x), cov(m_x) + cholinv(Λ_out)) +-end +- +-@rule KernelGCV(:y, Marginalisation) (m_x::MvNormalMeanPrecision, q_z::MvNormalMeanPrecision, meta::KernelGCVMetadata) = begin +- kernelfunction = get_kernelfn(meta) +- Λ_out = approximate_kernel_expectation(get_approximation(meta), (s) -> inv(kernelfunction(s)), q_z) +- return MvNormalMeanPrecision(mean(m_x), cholinv(cov(m_x) + cholinv(Λ_out))) +-end +diff --git a/src/rules/kernel_gcv/z.jl b/src/rules/kernel_gcv/z.jl +deleted file mode 100644 +index 5ebb4978..00000000 +--- a/src/rules/kernel_gcv/z.jl ++++ /dev/null +@@ -1,53 +0,0 @@ +-export rule +- +-@rule KernelGCV(:z, Marginalisation) (q_y_x::MvNormalMeanCovariance, meta::KernelGCVMetadata) = begin +- dims = Int64(ndims(q_y_x) / 2) +- +- m_yx = mean(q_y_x) +- cov_yx = cov(q_y_x) +- +- cov11 = @view cov_yx[1:dims, 1:dims] +- cov12 = @view cov_yx[1:dims, (dims + 1):end] +- cov21 = @view cov_yx[(dims + 1):end, 1:dims] +- cov22 = @view cov_yx[(dims + 1):end, (dims + 1):end] +- +- m1 = @view m_yx[1:dims] +- m2 = @view m_yx[(dims + 1):end] +- +- psi = cov11 + cov22 - cov12 - cov21 + (m1 - m2) * (m1 - m2)' +- +- kernelfunction = get_kernelfn(meta) +- +- logpdf = (z) -> begin +- gz = kernelfunction(z) +- -0.5 * (logdet(gz) + tr(cholinv(gz) * psi)) +- end +- +- return FnWithApproximation(logpdf, get_approximation(meta)) +-end +- +-@rule KernelGCV(:z, Marginalisation) (q_y_x::MvNormalMeanPrecision, meta::KernelGCVMetadata) = begin +- dims = Int64(ndims(q_y_x) / 2) +- +- m_yx = mean(q_y_x) +- cov_yx = cov(q_y_x) +- +- cov11 = @view cov_yx[1:dims, 1:dims] +- cov12 = @view cov_yx[1:dims, (dims + 1):end] +- cov21 = @view cov_yx[(dims + 1):end, 1:dims] +- cov22 = @view cov_yx[(dims + 1):end, (dims + 1):end] +- +- m1 = @view m_yx[1:dims] +- m2 = @view m_yx[(dims + 1):end] +- +- psi = cov11 + cov22 - cov12 - cov21 + (m1 - m2) * (m1 - m2)' +- +- kernelfunction = get_kernelfn(meta) +- +- logpdf = (z) -> begin +- gz = kernelfunction(z) +- -0.5 * (logdet(gz) + tr(cholinv(gz) * psi)) +- end +- +- return FnWithApproximation(logpdf, get_approximation(meta)) +-end +diff --git a/src/rules/mv_autoregressive/a.jl b/src/rules/mv_autoregressive/a.jl +deleted file mode 100644 +index 0777efb7..00000000 +--- a/src/rules/mv_autoregressive/a.jl ++++ /dev/null +@@ -1,50 +0,0 @@ +- +-@rule MAR(:a, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin +- order, ds = getorder(meta), getdimensionality(meta) +- F = Multivariate +- +- dim = order * ds +- +- m, V = mean_cov(q_y_x) +- +- my, Vy = ar_slice(F, m, 1:dim), ar_slice(F, V, 1:dim, 1:dim) +- mx, Vx = ar_slice(F, m, (dim + 1):(2dim)), ar_slice(F, V, (dim + 1):(2dim), (dim + 1):(2dim)) +- Vyx = ar_slice(F, V, 1:dim, (dim + 1):(2dim)) +- +- mΛ = mean(q_Λ) +- mW = mar_transition(order, mΛ) +- +- # this should be inside MARMeta +- es = [uvector(dim, i) 
for i in 1:ds] +- Fs = [mask_mar(order, ds, i) for i in 1:ds] +- S = mar_shift(order, ds) +- +- # NOTE: prove that sum(Fs[i]'*((mx*mx'+Vx')*S')*mW*es[i] for i in 1:ds) == 0.0 +- D = sum(sum(es[i]' * mW * es[j] * Fs[i]' * (mx * mx' + Vx) * Fs[j] for i in 1:ds) for j in 1:ds) +- z = sum(Fs[i]' * (mx * my' + Vyx') * mW * es[i] for i in 1:ds) +- +- return MvNormalWeightedMeanPrecision(z, D) +-end +- +-@rule MAR(:a, Marginalisation) (q_y::MultivariateNormalDistributionsFamily, q_x::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin +- order, ds = getorder(meta), getdimensionality(meta) +- F = Multivariate +- +- dim = order * ds +- +- my, Vy = mean_cov(q_y) +- mx, Vx = mean_cov(q_x) +- mΛ = mean(q_Λ) +- +- mW = mar_transition(order, mΛ) +- S = mar_shift(order, ds) +- +- # this should be inside MARMeta +- es = [uvector(dim, i) for i in 1:ds] +- Fs = [mask_mar(order, ds, i) for i in 1:ds] +- +- D = sum(sum(es[j]' * mW * es[i] * Fs[i]' * (mx * mx' + Vx) * Fs[j] for i in 1:ds) for j in 1:ds) +- z = sum(Fs[i]' * ((mx * mx' + Vx') * S' + mx * my') * mW * es[i] for i in 1:ds) +- +- return MvNormalWeightedMeanPrecision(z, D) +-end +diff --git a/src/rules/mv_autoregressive/lambda.jl b/src/rules/mv_autoregressive/lambda.jl +deleted file mode 100644 +index 29f88cba..00000000 +--- a/src/rules/mv_autoregressive/lambda.jl ++++ /dev/null +@@ -1,53 +0,0 @@ +-@rule MAR(:Λ, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, meta::MARMeta) = begin +- order, ds = getorder(meta), getdimensionality(meta) +- F = Multivariate +- dim = order * ds +- +- ma, Va = mean_cov(q_a) +- +- mA = mar_companion_matrix(order, ds, ma) +- +- m, V = mean_cov(q_y_x) +- my, Vy = ar_slice(F, m, 1:dim), ar_slice(F, V, 1:dim, 1:dim) +- mx, Vx = ar_slice(F, m, (dim + 1):(2dim)), ar_slice(F, V, (dim + 1):(2dim), (dim + 1):(2dim)) +- Vyx = ar_slice(F, V, 1:dim, (dim + 1):(2dim)) +- +- es = [uvector(dim, i) for i in 1:ds] +- Fs = [mask_mar(order, ds, i) for i in 1:ds] +- S = mar_shift(order, ds) +- G₁ = (my * my' + Vy)[1:ds, 1:ds] +- G₂ = ((my * mx' + Vyx) * mA')[1:ds, 1:ds] +- G₃ = transpose(G₂) +- Ex_xx = mx * mx' + Vx +- G₅ = sum(sum(es[i] * ma' * Fs[i]'Ex_xx * Fs[j] * ma * es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] +- G₆ = sum(sum(es[i] * tr(Fs[i]' * Ex_xx * Fs[j] * Va) * es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] +- Δ = G₁ - G₂ - G₃ + G₅ + G₆ +- +- return WishartMessage(ds + 2, Δ) +-end +- +-@rule MAR(:Λ, Marginalisation) (q_y::MultivariateNormalDistributionsFamily, q_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, meta::MARMeta) = +- begin +- order, ds = getorder(meta), getdimensionality(meta) +- F = Multivariate +- dim = order * ds +- +- my, Vy = mean_cov(q_y) +- mx, Vx = mean_cov(q_x) +- ma, Va = mean_cov(q_a) +- +- mA = mar_companion_matrix(order, ds, ma) +- +- es = [uvector(dim, i) for i in 1:ds] +- Fs = [mask_mar(order, ds, i) for i in 1:ds] +- S = mar_shift(order, ds) +- G₁ = (my * my' + Vy)[1:ds, 1:ds] +- G₂ = (my * mx' * mA')[1:ds, 1:ds] +- G₃ = transpose(G₂) +- Ex_xx = mx * mx' + Vx +- G₅ = sum(sum(es[i] * ma' * Fs[j]'Ex_xx * Fs[i] * ma * es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] +- G₆ = sum(sum(es[i] * tr(Va * Fs[i]' * Ex_xx * Fs[j]) * es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] +- Δ = G₁ - G₂ - G₃ + G₅ + G₆ +- +- return WishartMessage(ds + 2, Δ) +- end +diff --git a/src/rules/mv_autoregressive/marginals.jl b/src/rules/mv_autoregressive/marginals.jl +deleted file mode 100644 +index 
92a71a65..00000000 +--- a/src/rules/mv_autoregressive/marginals.jl ++++ /dev/null +@@ -1,46 +0,0 @@ +- +-@marginalrule MAR(:y_x) ( +- m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta +-) = begin +- return ar_y_x_marginal(m_y, m_x, q_a, q_Λ, meta) +-end +- +-function ar_y_x_marginal( +- m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta +-) +- order, ds = getorder(meta), getdimensionality(meta) +- F = Multivariate +- dim = order * ds +- +- ma, Va = mean_cov(q_a) +- mΛ = mean(q_Λ) +- +- mA = mar_companion_matrix(order, ds, ma) +- mW = mar_transition(getorder(meta), mΛ) +- +- b_my, b_Vy = mean_cov(m_y) +- f_mx, f_Vx = mean_cov(m_x) +- +- inv_b_Vy = cholinv(b_Vy) +- inv_f_Vx = cholinv(f_Vx) +- +- # this should be inside MARMeta +- es = [uvector(dim, i) for i in 1:ds] +- Fs = [mask_mar(order, ds, i) for i in 1:ds] +- +- Ξ = inv_f_Vx + sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds) +- +- W_11 = inv_b_Vy + mW +- +- # negate_inplace!(mW * mA) +- W_12 = -(mW * mA) +- +- W_21 = -(mA' * mW) +- +- W_22 = Ξ + mA' * mW * mA +- +- W = [W_11 W_12; W_21 W_22] +- ξ = [inv_b_Vy * b_my; inv_f_Vx * f_mx] +- +- return MvNormalWeightedMeanPrecision(ξ, W) +-end +diff --git a/src/rules/mv_autoregressive/x.jl b/src/rules/mv_autoregressive/x.jl +deleted file mode 100644 +index e191589a..00000000 +--- a/src/rules/mv_autoregressive/x.jl ++++ /dev/null +@@ -1,50 +0,0 @@ +- +-@rule MAR(:x, Marginalisation) (m_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin +- ma, Va = mean_cov(q_a) +- my, Vy = mean_cov(m_y) +- +- mΛ = mean(q_Λ) +- +- order, ds = getorder(meta), getdimensionality(meta) +- dim = order * ds +- +- mA = mar_companion_matrix(order, ds, ma) +- mW = mar_transition(getorder(meta), mΛ) +- # this should be inside MARMeta +- es = [uvector(dim, i) for i in 1:ds] +- Fs = [mask_mar(order, ds, i) for i in 1:ds] +- +- Λ = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds) +- +- Σ₁ = Hermitian(pinv(mA) * (Vy) * pinv(mA') + pinv(mA' * mW * mA)) +- +- Ξ = (pinv(Σ₁) + Λ) +- z = pinv(Σ₁) * pinv(mA) * my +- +- return MvNormalWeightedMeanPrecision(z, Ξ) +-end +- +-@rule MAR(:x, Marginalisation) (q_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin +- ma, Va = mean_cov(q_a) +- my, Vy = mean_cov(q_y) +- +- mΛ = mean(q_Λ) +- +- order, ds = getorder(meta), getdimensionality(meta) +- dim = order * ds +- +- mA = mar_companion_matrix(order, ds, ma) +- mW = mar_transition(getorder(meta), mΛ) +- +- # this should be inside MARMeta +- es = [uvector(dim, i) for i in 1:ds] +- Fs = [mask_mar(order, ds, i) for i in 1:ds] +- +- Λ = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds) +- Λ₀ = Hermitian(mA' * mW * mA) +- +- Ξ = Λ₀ + Λ +- z = Λ₀ * pinv(mA) * my +- +- return MvNormalWeightedMeanPrecision(z, Ξ) +-end +diff --git a/src/rules/mv_autoregressive/y.jl b/src/rules/mv_autoregressive/y.jl +deleted file mode 100644 +index b99ace9b..00000000 +--- a/src/rules/mv_autoregressive/y.jl ++++ /dev/null +@@ -1,34 +0,0 @@ +-@rule MAR(:y, Marginalisation) (m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin +- ma, Va = mean_cov(q_a) +- mx, Wx = 
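
Note: `ar_y_x_marginal` above assembles the joint (y, x) precision in block form, mirroring the Gaussian identity N(y; Ax, W⁻¹) with the incoming messages folded in. A toy check of that block layout with illustrative matrices (not package API):

using LinearAlgebra

mW = [2.0 0.0; 0.0 2.0]      # E[W], toy value
mA = [0.5 0.1; 0.0 0.5]      # E[A], toy value
Wy = Matrix(1.0I, 2, 2)      # precision of the backward message on y
Ξ  = Matrix(1.0I, 2, 2)      # forward precision on x plus the Va correction term

W11 = Wy + mW
W12 = -mW * mA
W21 = -(mA' * mW)
W22 = Ξ + mA' * mW * mA
W   = [W11 W12; W21 W22]     # joint precision over (y, x)

@assert W ≈ W'               # symmetric by construction
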
mean_invcov(m_x) +- +- mΛ = mean(q_Λ) +- +- order, ds = getorder(meta), getdimensionality(meta) +- +- mA = mar_companion_matrix(order, ds, ma) +- mW = mar_transition(getorder(meta), mΛ) +- dim = order * ds +- # this should be inside MARMeta +- es = [uvector(dim, i) for i in 1:ds] +- Fs = [mask_mar(order, ds, i) for i in 1:ds] +- +- Λ = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds) +- +- Ξ = Λ + Wx +- z = Wx * mx +- +- Vy = mA * inv(Ξ) * mA' + inv(mW) +- my = mA * inv(Ξ) * z +- +- return MvNormalMeanCovariance(my, Vy) +-end +- +-@rule MAR(:y, Marginalisation) (q_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin +- order, ds = getorder(meta), getdimensionality(meta) +- +- mA = mar_companion_matrix(order, ds, mean(q_a)) +- mW = mar_transition(getorder(meta), mean(q_Λ)) +- +- return MvNormalMeanPrecision(mA * mean(q_x), mW) +-end +diff --git a/src/rules/prototypes.jl b/src/rules/prototypes.jl +index 81395dd7..5c56ab65 100644 +--- a/src/rules/prototypes.jl ++++ b/src/rules/prototypes.jl +@@ -52,11 +52,6 @@ include("gcv/w.jl") + include("gcv/marginals.jl") + include("gcv/gaussian_extension.jl") + +-include("kernel_gcv/x.jl") +-include("kernel_gcv/y.jl") +-include("kernel_gcv/z.jl") +-include("kernel_gcv/marginals.jl") +- + include("mv_normal_mean_covariance/out.jl") + include("mv_normal_mean_covariance/mean.jl") + include("mv_normal_mean_covariance/covariance.jl") +@@ -116,12 +111,6 @@ include("autoregressive/theta.jl") + include("autoregressive/gamma.jl") + include("autoregressive/marginals.jl") + +-include("mv_autoregressive/y.jl") +-include("mv_autoregressive/x.jl") +-include("mv_autoregressive/a.jl") +-include("mv_autoregressive/lambda.jl") +-include("mv_autoregressive/marginals.jl") +- + include("probit/marginals.jl") + include("probit/in.jl") + include("probit/out.jl") +diff --git a/src/variables/constant.jl b/src/variables/constant.jl +index 6cb40aed..b808acd6 100644 +--- a/src/variables/constant.jl ++++ b/src/variables/constant.jl +@@ -50,8 +50,8 @@ function constvar end + + constvar(name::Symbol, constval, collection_type::AbstractVariableCollectionType = VariableIndividual()) = ConstVariable(name, collection_type, constval, of(Message(constval, true, false, nothing)), 0) + constvar(name::Symbol, constval::Real, collection_type::AbstractVariableCollectionType = VariableIndividual()) = constvar(name, PointMass(constval), collection_type) +-constvar(name::Symbol, constval::AbstractVector, collection_type::AbstractVariableCollectionType = VariableIndividual()) = constvar(name, PointMass(constval), collection_type) +-constvar(name::Symbol, constval::AbstractMatrix, collection_type::AbstractVariableCollectionType = VariableIndividual()) = constvar(name, PointMass(constval), collection_type) ++constvar(name::Symbol, constval::AbstractArray, collection_type::AbstractVariableCollectionType = VariableIndividual()) = constvar(name, PointMass(constval), collection_type) ++constvar(name::Symbol, constval::UniformScaling, collection_type::AbstractVariableCollectionType = VariableIndividual()) = constvar(name, PointMass(constval), collection_type) + + function constvar(name::Symbol, fn::Function, length::Int) + return map(i -> constvar(name, fn(i), VariableVector(i)), 1:length) +diff --git a/src/variables/data.jl b/src/variables/data.jl +index beb05410..6600cda8 100644 +--- a/src/variables/data.jl ++++ b/src/variables/data.jl +@@ -9,15 +9,19 @@ mutable struct DataVariable{D, S} <: AbstractVariable + 
input_messages :: Vector{MessageObservable{AbstractMessage}} + messageout :: S + nconnected :: Int ++ isproxy :: Bool ++ isused :: Bool + end + + Base.show(io::IO, datavar::DataVariable) = print(io, "DataVariable(", indexed_name(datavar), ")") + + struct DataVariableCreationOptions{S} +- subject::S ++ subject :: S ++ isproxy :: Bool ++ isused :: Bool + end + +-Base.similar(options::DataVariableCreationOptions) = DataVariableCreationOptions(similar(options.subject)) ++Base.similar(options::DataVariableCreationOptions) = DataVariableCreationOptions(similar(options.subject), options.isproxy, options.isused) + + DataVariableCreationOptions(::Type{D}) where {D} = DataVariableCreationOptions(D, nothing) + DataVariableCreationOptions(::Type{D}, subject) where {D} = DataVariableCreationOptions(D, subject, Val(false)) +@@ -26,7 +30,7 @@ DataVariableCreationOptions(::Type{D}, subject::Nothing, allow_missing::Val{true + DataVariableCreationOptions(::Type{D}, subject::Nothing, allow_missing::Val{false}) where {D} = DataVariableCreationOptions(D, RecentSubject(Union{Message{D}}), Val(false)) + + DataVariableCreationOptions(::Type{D}, subject::S, ::Val{true}) where {D, S} = error("Error in datavar options. Custom `subject` was specified and `allow_missing` was set to true, which is disallowed. Provide a custom subject that accept missing values by itself and do no use `allow_missing` option.") +-DataVariableCreationOptions(::Type{D}, subject::S, ::Val{false}) where {D, S} = DataVariableCreationOptions{S}(subject) ++DataVariableCreationOptions(::Type{D}, subject::S, ::Val{false}) where {D, S} = DataVariableCreationOptions{S}(subject, false, false) + + """ + datavar(::Type, [ dims... ]) +@@ -72,7 +76,7 @@ datavar(name::Symbol, ::Type{D}, dims::Tuple) where {D} + datavar(name::Symbol, ::Type{D}, dims::Vararg{Int}) where {D} = datavar(DataVariableCreationOptions(D), name, D, dims) + + datavar(options::DataVariableCreationOptions{S}, name::Symbol, ::Type{D}, collection_type::AbstractVariableCollectionType = VariableIndividual()) where {S, D} = +- DataVariable{D, S}(name, collection_type, MarginalObservable(), Vector{MessageObservable{AbstractMessage}}(), options.subject, 0) ++ DataVariable{D, S}(name, collection_type, MarginalObservable(), Vector{MessageObservable{AbstractMessage}}(), options.subject, 0, options.isproxy, options.isused) + + function datavar(options::DataVariableCreationOptions, name::Symbol, ::Type{D}, length::Int) where {D} + return map(i -> datavar(similar(options), name, D, VariableVector(i)), 1:length) +@@ -93,12 +97,13 @@ Base.eltype(::DataVariable{D}) where {D} = D + + degree(datavar::DataVariable) = nconnected(datavar) + name(datavar::DataVariable) = datavar.name +-proxy_variables(datavar::DataVariable) = nothing ++proxy_variables(datavar::DataVariable) = nothing # not related to isproxy + collection_type(datavar::DataVariable) = datavar.collection_type + isconnected(datavar::DataVariable) = datavar.nconnected !== 0 + nconnected(datavar::DataVariable) = datavar.nconnected + +-isproxy(::DataVariable) = false ++isproxy(datavar::DataVariable) = datavar.isproxy ++isused(datavar::DataVariable) = datavar.isused + + israndom(::DataVariable) = false + israndom(::AbstractArray{<:DataVariable}) = false +@@ -117,7 +122,7 @@ function Base.getindex(datavar::DataVariable, i...) + error("Variable $(indexed_name(datavar)) has been indexed with `[$(join(i, ','))]`. 
Direct indexing of `data` variables is not allowed.")
+ end
+
+-getlastindex(::DataVariable) = 1
++getlastindex(datavar::DataVariable) = degree(datavar) + 1
+
+ messageout(datavar::DataVariable, ::Int) = datavar.messageout
+ messagein(datavar::DataVariable, ::Int) = error("It is not possible to get a reference for inbound message for datavar")
+@@ -168,16 +173,18 @@ _getmarginal(datavar::DataVariable) = datavar.messageout |> map(Mar
+ _setmarginal!(datavar::DataVariable, observable) = error("It is not possible to set a marginal stream for `DataVariable`")
+ _makemarginal(datavar::DataVariable) = error("It is not possible to make marginal stream for `DataVariable`")
+
+-# Extension for _getmarginal
+-function Rocket.getrecent(proxy::ProxyObservable{<:Marginal, S, M}) where {S <: Rocket.RecentSubjectInstance, D, M <: Rocket.MapProxy{D, typeof(as_marginal)}}
+-    return as_marginal(Rocket.getrecent(proxy.proxied_source))
+-end
+-
+ setanonymous!(::DataVariable, ::Bool) = nothing
+
+-function setmessagein!(datavar::DataVariable, ::Int, messagein)
+-    datavar.nconnected += 1
+-    push!(datavar.input_messages, messagein)
++function setmessagein!(datavar::DataVariable, index::Int, messagein)
++    if index === (degree(datavar) + 1)
++        push!(datavar.input_messages, messagein)
++        datavar.nconnected += 1
++        datavar.isused = true
++    else
++        error(
++            "Inconsistent state in setmessagein! function for data variable $(datavar). `index` should be equal to `degree(datavar) + 1 = $(degree(datavar) + 1)`, $(index) is given instead"
++        )
++    end
+     return nothing
+ end
+
+diff --git a/src/variables/variable.jl b/src/variables/variable.jl
+index 5d2be480..bf990a7b 100644
+--- a/src/variables/variable.jl
++++ b/src/variables/variable.jl
+@@ -147,7 +147,7 @@ track of `proxy_variables`. During the first call of `get_factorisation_referenc
+ 2. if yes we pass it further to the `unchecked` version of the function
+ 2.1 `unchecked` version returns immediately if there is only one proxy var (see bullet 1)
+ 2.2 in case of multiple proxy vars we filter only `RandomVariable` and call `checked` version of the function
+-3. `checked` version of the function return immediatelly if there is only one proxy random variable left, if there are multuple proxy random vars we throw an error as this case is ambigous for factorisation constrains specification
++3. `checked` version of the function returns immediately if there is only one proxy random variable left; if there are multiple proxy random vars we throw an error as this case is ambiguous for factorisation constraints specification
+
+ This function is a part of private API and should not be used explicitly. 
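
Note: the rewritten `setmessagein!` above makes inbound connections strictly ordered. A simplified standalone model of the same invariant (hypothetical `ToyVar` type, not package API):

mutable struct ToyVar
    inputs::Vector{Symbol}
end

function connect!(v::ToyVar, index::Int, msg::Symbol)
    expected = length(v.inputs) + 1          # only the next free slot may be connected
    index == expected || error("expected index $expected, got $index")
    push!(v.inputs, msg)
    return nothing
end

v = ToyVar(Symbol[])
connect!(v, 1, :m1)       # fine: the variable is now marked as used
# connect!(v, 100, :m2)   # would throw, mirroring the test change further below
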
+ """ +diff --git a/test/approximations/test_cvi.jl b/test/approximations/test_cvi.jl +index fbc9f3b7..cf6a1df0 100644 +--- a/test/approximations/test_cvi.jl ++++ b/test/approximations/test_cvi.jl +@@ -80,8 +80,8 @@ end + rng = StableRNG(42) + + tests = ( +- (method = CVI(StableRNG(42), 1, 1000, Descent(0.01), ForwardDiffGrad(), 1, Val(true), false), tol = 5e-1), +- (method = CVI(StableRNG(42), 1, 1000, Descent(0.01), ZygoteGrad(), 1, Val(true), false), tol = 5e-1) ++ (method = CVI(StableRNG(42), 1, 1000, Descent(0.01), ForwardDiffGrad(), 10, Val(true), false), tol = 5e-1), ++ (method = CVI(StableRNG(42), 1, 1000, Descent(0.01), ZygoteGrad(), 10, Val(true), false), tol = 5e-1) + ) + + # Check several prods against their analytical solutions +@@ -122,9 +122,16 @@ end + + b1 = Bernoulli(logistic(randn(rng))) + b2 = Bernoulli(logistic(randn(rng))) +- b_analitical = prod(ProdAnalytical(), b1, b2) ++ b_analytical = prod(ProdAnalytical(), b1, b2) + b_cvi = prod(test[:method], b1, b1) +- @test isapprox(mean(b_analitical), mean(b_cvi), atol = test[:tol]) ++ @test isapprox(mean(b_analytical), mean(b_cvi), atol = test[:tol]) ++ ++ beta_1 = Beta(abs(randn(rng)) + 1, abs(randn(rng)) + 1) ++ beta_2 = Beta(abs(randn(rng)) + 1, abs(randn(rng)) + 1) ++ ++ beta_analytical = prod(ProdAnalytical(), beta_1, beta_2) ++ beta_cvi = prod(test[:method], beta_1, beta_2) ++ @test isapprox(mean(beta_analytical), mean(beta_cvi), atol = test[:tol]) + end + end + +diff --git a/test/distributions/test_bernoulli.jl b/test/distributions/test_bernoulli.jl +index 2abc4968..8a469cd6 100644 +--- a/test/distributions/test_bernoulli.jl ++++ b/test/distributions/test_bernoulli.jl +@@ -26,6 +26,18 @@ using ReactiveMP: compute_logscale + @test prod(ProdAnalytical(), Bernoulli(0.78), Bernoulli(0.05)) ≈ Bernoulli(0.1572580645161291) + end + ++ @testset "prod Bernoulli-Categorical" begin ++ @test prod(ProdAnalytical(), Bernoulli(0.5), Categorical([1.0])) == Categorical([1.0, 0.0]) ++ @test prod(ProdAnalytical(), Bernoulli(0.6), Categorical([0.7, 0.3])) == Categorical([0.6086956521739131, 0.391304347826087]) ++ @test prod(ProdAnalytical(), Bernoulli(0.8), Categorical([0.2, 0.4, 0.4])) == Categorical([0.11111111111111108, 0.8888888888888888, 0.0]) ++ end ++ ++ @testset "prod Categorical-Bernoulli" begin ++ @test prod(ProdAnalytical(), Categorical([1.0]), Bernoulli(0.5)) == Categorical([1.0, 0.0]) ++ @test prod(ProdAnalytical(), Categorical([0.7, 0.3]), Bernoulli(0.6)) == Categorical([0.6086956521739131, 0.391304347826087]) ++ @test prod(ProdAnalytical(), Categorical([0.2, 0.4, 0.4]), Bernoulli(0.8)) == Categorical([0.11111111111111108, 0.8888888888888888, 0.0]) ++ end ++ + @testset "probvec" begin + @test probvec(Bernoulli(0.5)) === (0.5, 0.5) + @test probvec(Bernoulli(0.3)) === (0.7, 0.3) +diff --git a/test/distributions/test_beta.jl b/test/distributions/test_beta.jl +index 0510f61c..14b4f5ab 100644 +--- a/test/distributions/test_beta.jl ++++ b/test/distributions/test_beta.jl +@@ -6,6 +6,7 @@ using Distributions + using Random + + import ReactiveMP: mirrorlog ++import SpecialFunctions: loggamma + + @testset "Beta" begin + +@@ -37,6 +38,38 @@ import ReactiveMP: mirrorlog + @test mean(mirrorlog, Beta(0.1, 0.3)) ≈ -0.9411396776150167 + @test mean(mirrorlog, Beta(4.5, 0.3)) ≈ -4.963371962929249 + end ++ ++ @testset "BetaNaturalParameters" begin ++ @testset "Constructor" begin ++ for i in 0:10, j in 0:10 ++ @test convert(Distribution, BetaNaturalParameters(i, j)) == Beta(i + 1, j + 1) ++ ++ @test convert(BetaNaturalParameters, i, j) == 
BetaNaturalParameters(i, j) ++ @test convert(BetaNaturalParameters, [i, j]) == BetaNaturalParameters(i, j) ++ end ++ end ++ ++ @testset "lognormalizer" begin ++ @test lognormalizer(BetaNaturalParameters(0, 0)) ≈ 0 ++ @test lognormalizer(BetaNaturalParameters(1, 1)) ≈ -loggamma(4) ++ end ++ ++ @testset "logpdf" begin ++ for i in 0:10, j in 0:10 ++ @test logpdf(BetaNaturalParameters(i, j), 0.01) ≈ logpdf(Beta(i + 1, j + 1), 0.01) ++ @test logpdf(BetaNaturalParameters(i, j), 0.5) ≈ logpdf(Beta(i + 1, j + 1), 0.5) ++ end ++ end ++ ++ @testset "isproper" begin ++ for i in 0:10 ++ @test isproper(BetaNaturalParameters(i, i)) === true ++ end ++ for i in 1:10 ++ @test isproper(BetaNaturalParameters(-i, -i)) === false ++ end ++ end ++ end + end + + end +diff --git a/test/distributions/test_mv_normal_mean_covariance.jl b/test/distributions/test_mv_normal_mean_covariance.jl +index ef9af84d..55b4d4dd 100644 +--- a/test/distributions/test_mv_normal_mean_covariance.jl ++++ b/test/distributions/test_mv_normal_mean_covariance.jl +@@ -14,6 +14,13 @@ using Distributions + @test MvNormalMeanCovariance([1, 2]) == MvNormalMeanCovariance([1.0, 2.0], [1.0, 1.0]) + @test MvNormalMeanCovariance([1.0f0, 2.0f0]) == MvNormalMeanCovariance([1.0f0, 2.0f0], [1.0f0, 1.0f0]) + ++ # uniformscaling ++ @test MvNormalMeanCovariance([1, 2], I) == MvNormalMeanCovariance([1, 2], Diagonal([1, 1])) ++ @test MvNormalMeanCovariance([1, 2], 6 * I) == MvNormalMeanCovariance([1, 2], Diagonal([6, 6])) ++ @test MvNormalMeanCovariance([1.0, 2.0], I) == MvNormalMeanCovariance([1.0, 2.0], Diagonal([1.0, 1.0])) ++ @test MvNormalMeanCovariance([1.0, 2.0], 6 * I) == MvNormalMeanCovariance([1.0, 2.0], Diagonal([6.0, 6.0])) ++ @test MvNormalMeanCovariance([1, 2], 6.0 * I) == MvNormalMeanCovariance([1.0, 2.0], Diagonal([6.0, 6.0])) ++ + @test eltype(MvNormalMeanCovariance([1.0, 1.0])) === Float64 + @test eltype(MvNormalMeanCovariance([1.0, 1.0], [1.0, 1.0])) === Float64 + @test eltype(MvNormalMeanCovariance([1, 1])) === Float64 +@@ -91,6 +98,14 @@ using Distributions + dist = MvNormalMeanCovariance(μ, Σ) + + @test prod(ProdAnalytical(), dist, dist) ≈ MvNormalWeightedMeanPrecision([2.0, 2.0, 2.0], diagm([2.0, 1.0, 2 / 3])) ++ ++ # diagonal covariance matrix/uniformscaling ++ @test prod(ProdAnalytical(), MvNormalMeanCovariance([-1, -1], [2 0; 0 2]), MvNormalMeanCovariance([1, 1], Diagonal([2, 4]))) ≈ ++ MvNormalWeightedMeanPrecision([0, -1 / 4], [1, 3 / 4]) ++ @test prod(ProdAnalytical(), MvNormalMeanCovariance([-1, -1], [2, 2]), MvNormalMeanCovariance([1, 1], Diagonal([2, 4]))) ≈ ++ MvNormalWeightedMeanPrecision([0, -1 / 4], [1, 3 / 4]) ++ @test prod(ProdAnalytical(), MvNormalMeanCovariance([-1, -1], 2 * I), MvNormalMeanCovariance([1, 1], Diagonal([2, 4]))) ≈ ++ MvNormalWeightedMeanPrecision([0, -1 / 4], [1, 3 / 4]) + end + + @testset "Primitive types conversion" begin +diff --git a/test/distributions/test_mv_normal_mean_precision.jl b/test/distributions/test_mv_normal_mean_precision.jl +index 97afa3b5..4f00fd2e 100644 +--- a/test/distributions/test_mv_normal_mean_precision.jl ++++ b/test/distributions/test_mv_normal_mean_precision.jl +@@ -14,6 +14,13 @@ using Distributions + @test MvNormalMeanPrecision([1, 2]) == MvNormalMeanPrecision([1.0, 2.0], [1.0, 1.0]) + @test MvNormalMeanPrecision([1.0f0, 2.0f0]) == MvNormalMeanPrecision([1.0f0, 2.0f0], [1.0f0, 1.0f0]) + ++ # uniformscaling ++ @test MvNormalMeanPrecision([1, 2], I) == MvNormalMeanPrecision([1, 2], Diagonal([1, 1])) ++ @test MvNormalMeanPrecision([1, 2], 6 * I) == MvNormalMeanPrecision([1, 
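
Note: these `BetaNaturalParameters` tests encode the usual exponential-family identities: natural parameters (η₁, η₂) correspond to Beta(η₁ + 1, η₂ + 1) with log-normalizer log B(η₁ + 1, η₂ + 1). A self-contained numeric check (assumes only SpecialFunctions.jl; helper names are illustrative):

using SpecialFunctions: loggamma

logbeta_fn(a, b) = loggamma(a) + loggamma(b) - loggamma(a + b)   # log B(a, b)

# natural parameters (η₁, η₂) map to Beta(η₁ + 1, η₂ + 1)
lognormalizer_sketch(η₁, η₂) = logbeta_fn(η₁ + 1, η₂ + 1)

@assert lognormalizer_sketch(0, 0) ≈ 0.0            # Beta(1, 1), the uniform case
@assert lognormalizer_sketch(1, 1) ≈ -loggamma(4)   # matches the Beta(2, 2) test above
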
2], Diagonal([6, 6])) ++ @test MvNormalMeanPrecision([1.0, 2.0], I) == MvNormalMeanPrecision([1.0, 2.0], Diagonal([1.0, 1.0])) ++ @test MvNormalMeanPrecision([1.0, 2.0], 6 * I) == MvNormalMeanPrecision([1.0, 2.0], Diagonal([6.0, 6.0])) ++ @test MvNormalMeanPrecision([1, 2], 6.0 * I) == MvNormalMeanPrecision([1.0, 2.0], Diagonal([6.0, 6.0])) ++ + @test eltype(MvNormalMeanPrecision([1.0, 1.0])) === Float64 + @test eltype(MvNormalMeanPrecision([1.0, 1.0], [1.0, 1.0])) === Float64 + @test eltype(MvNormalMeanPrecision([1, 1])) === Float64 +@@ -91,6 +98,11 @@ using Distributions + dist = MvNormalMeanPrecision(μ, Λ) + + @test prod(ProdAnalytical(), dist, dist) ≈ MvNormalWeightedMeanPrecision([2.0, 2.0, 2.0], diagm([2.0, 1.0, 2 / 3])) ++ ++ # diagonal covariance matrix/uniformscaling ++ @test prod(ProdAnalytical(), MvNormalMeanPrecision([-1, -1], [2 0; 0 2]), MvNormalMeanPrecision([1, 1], Diagonal([2, 4]))) ≈ MvNormalWeightedMeanPrecision([0, 2], [4, 6]) ++ @test prod(ProdAnalytical(), MvNormalMeanPrecision([-1, -1], [2, 2]), MvNormalMeanPrecision([1, 1], Diagonal([2, 4]))) ≈ MvNormalWeightedMeanPrecision([0, 2], [4, 6]) ++ @test prod(ProdAnalytical(), MvNormalMeanPrecision([-1, -1], 2 * I), MvNormalMeanPrecision([1, 1], Diagonal([2, 4]))) ≈ MvNormalWeightedMeanPrecision([0, 2], [4, 6]) + end + + @testset "Primitive types conversion" begin +diff --git a/test/distributions/test_mv_normal_weighted_mean_precision.jl b/test/distributions/test_mv_normal_weighted_mean_precision.jl +index 681adf6d..ee28cd68 100644 +--- a/test/distributions/test_mv_normal_weighted_mean_precision.jl ++++ b/test/distributions/test_mv_normal_weighted_mean_precision.jl +@@ -14,6 +14,13 @@ using Distributions + @test MvNormalWeightedMeanPrecision([1, 2]) == MvNormalWeightedMeanPrecision([1.0, 2.0], [1.0, 1.0]) + @test MvNormalWeightedMeanPrecision([1.0f0, 2.0f0]) == MvNormalWeightedMeanPrecision([1.0f0, 2.0f0], [1.0f0, 1.0f0]) + ++ # uniformscaling ++ @test MvNormalWeightedMeanPrecision([1, 2], I) == MvNormalWeightedMeanPrecision([1, 2], Diagonal([1, 1])) ++ @test MvNormalWeightedMeanPrecision([1, 2], 6 * I) == MvNormalWeightedMeanPrecision([1, 2], Diagonal([6, 6])) ++ @test MvNormalWeightedMeanPrecision([1.0, 2.0], I) == MvNormalWeightedMeanPrecision([1.0, 2.0], Diagonal([1.0, 1.0])) ++ @test MvNormalWeightedMeanPrecision([1.0, 2.0], 6 * I) == MvNormalWeightedMeanPrecision([1.0, 2.0], Diagonal([6.0, 6.0])) ++ @test MvNormalWeightedMeanPrecision([1, 2], 6.0 * I) == MvNormalWeightedMeanPrecision([1.0, 2.0], Diagonal([6.0, 6.0])) ++ + @test eltype(MvNormalWeightedMeanPrecision([1.0, 1.0])) === Float64 + @test eltype(MvNormalWeightedMeanPrecision([1.0, 1.0], [1.0, 1.0])) === Float64 + @test eltype(MvNormalWeightedMeanPrecision([1, 1])) === Float64 +@@ -91,6 +98,14 @@ using Distributions + dist = MvNormalWeightedMeanPrecision(xi, Λ) + + @test prod(ProdAnalytical(), dist, dist) ≈ MvNormalWeightedMeanPrecision([0.40, 6.00, 8.00], [3.00 -0.20 0.20; -0.20 3.60 0.00; 0.20 0.00 7.00]) ++ ++ # diagonal covariance matrix/uniformscaling ++ @test prod(ProdAnalytical(), MvNormalWeightedMeanPrecision([-1, -1], [2 0; 0 2]), MvNormalWeightedMeanPrecision([1, 1], Diagonal([2, 4]))) ≈ ++ MvNormalWeightedMeanPrecision([0, 0], [4, 6]) ++ @test prod(ProdAnalytical(), MvNormalWeightedMeanPrecision([-1, -1], [2, 2]), MvNormalWeightedMeanPrecision([1, 1], Diagonal([2, 4]))) ≈ ++ MvNormalWeightedMeanPrecision([0, 0], [4, 6]) ++ @test prod(ProdAnalytical(), MvNormalWeightedMeanPrecision([-1, -1], 2 * I), MvNormalWeightedMeanPrecision([1, 1], Diagonal([2, 
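
Note: the `UniformScaling` constructor tests here and below all rely on one convention: a covariance or precision given as `λ * I` materializes as `Diagonal(fill(λ, d))` with `d` taken from the mean vector. A sketch of that convention with a hypothetical helper name:

using LinearAlgebra

# hypothetical helper mirroring the convention under test
materialize(μ::AbstractVector, s::UniformScaling) = Diagonal(fill(s.λ, length(μ)))

@assert materialize([1, 2], I) == Diagonal([1, 1])
@assert materialize([1.0, 2.0], 6 * I) == Diagonal([6.0, 6.0])
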
4]))) ≈ ++ MvNormalWeightedMeanPrecision([0, 0], [4, 6]) + end + + @testset "Primitive types conversion" begin +diff --git a/test/distributions/test_normal_mean_precision.jl b/test/distributions/test_normal_mean_precision.jl +index 3db706ac..006ef088 100644 +--- a/test/distributions/test_normal_mean_precision.jl ++++ b/test/distributions/test_normal_mean_precision.jl +@@ -3,6 +3,8 @@ module NormalMeanPrecisionTest + using Test + using ReactiveMP + ++using LinearAlgebra: I ++ + @testset "NormalMeanPrecision" begin + @testset "Constructor" begin + @test NormalMeanPrecision <: NormalDistributionsFamily +@@ -20,6 +22,13 @@ using ReactiveMP + @test NormalMeanPrecision(1.0f0, 2) == NormalMeanPrecision{Float32}(1.0f0, 2.0f0) + @test NormalMeanPrecision(1.0f0, 2.0) == NormalMeanPrecision{Float64}(1.0, 2.0) + ++ # uniformscaling ++ @test NormalMeanPrecision(2, I) == NormalMeanPrecision(2, 1) ++ @test NormalMeanPrecision(2, 6 * I) == NormalMeanPrecision(2, 6) ++ @test NormalMeanPrecision(2.0, I) == NormalMeanPrecision(2.0, 1.0) ++ @test NormalMeanPrecision(2.0, 6 * I) == NormalMeanPrecision(2.0, 6.0) ++ @test NormalMeanPrecision(2, 6.0 * I) == NormalMeanPrecision(2.0, 6.0) ++ + @test eltype(NormalMeanPrecision()) === Float64 + @test eltype(NormalMeanPrecision(0.0)) === Float64 + @test eltype(NormalMeanPrecision(0.0, 1.0)) === Float64 +diff --git a/test/distributions/test_normal_mean_variance.jl b/test/distributions/test_normal_mean_variance.jl +index 27b27389..a17ff8a4 100644 +--- a/test/distributions/test_normal_mean_variance.jl ++++ b/test/distributions/test_normal_mean_variance.jl +@@ -3,6 +3,8 @@ module NormalMeanVarianceTest + using Test + using ReactiveMP + ++using LinearAlgebra: I ++ + @testset "NormalMeanVariance" begin + @testset "Constructor" begin + @test NormalMeanVariance <: NormalDistributionsFamily +@@ -20,6 +22,13 @@ using ReactiveMP + @test NormalMeanVariance(1.0f0, 2) == NormalMeanVariance{Float32}(1.0f0, 2.0f0) + @test NormalMeanVariance(1.0f0, 2.0) == NormalMeanVariance{Float64}(1.0, 2.0) + ++ # uniformscaling ++ @test NormalMeanVariance(2, I) == NormalMeanVariance(2, 1) ++ @test NormalMeanVariance(2, 6 * I) == NormalMeanVariance(2, 6) ++ @test NormalMeanVariance(2.0, I) == NormalMeanVariance(2.0, 1.0) ++ @test NormalMeanVariance(2.0, 6 * I) == NormalMeanVariance(2.0, 6.0) ++ @test NormalMeanVariance(2, 6.0 * I) == NormalMeanVariance(2.0, 6.0) ++ + @test eltype(NormalMeanVariance()) === Float64 + @test eltype(NormalMeanVariance(0.0)) === Float64 + @test eltype(NormalMeanVariance(0.0, 1.0)) === Float64 +diff --git a/test/distributions/test_normal_weighted_mean_precision.jl b/test/distributions/test_normal_weighted_mean_precision.jl +index ace5ebfd..5a90d97e 100644 +--- a/test/distributions/test_normal_weighted_mean_precision.jl ++++ b/test/distributions/test_normal_weighted_mean_precision.jl +@@ -3,6 +3,8 @@ module NormalWeightedMeanPrecisionTest + using Test + using ReactiveMP + ++using LinearAlgebra: I ++ + @testset "NormalWeightedMeanPrecision" begin + @testset "Constructor" begin + @test NormalWeightedMeanPrecision <: NormalDistributionsFamily +@@ -19,6 +21,13 @@ using ReactiveMP + @test NormalWeightedMeanPrecision(1.0f0, 2.0f0) == NormalWeightedMeanPrecision{Float32}(1.0f0, 2.0f0) + @test NormalWeightedMeanPrecision(1.0f0, 2.0) == NormalWeightedMeanPrecision{Float64}(1.0, 2.0) + ++ # uniformscaling ++ @test NormalWeightedMeanPrecision(2, I) == NormalWeightedMeanPrecision(2, 1) ++ @test NormalWeightedMeanPrecision(2, 6 * I) == NormalWeightedMeanPrecision(2, 6) ++ @test 
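
Note: in weighted-mean/precision form the analytical product of two Gaussians is plain addition of the natural parameters, which makes the expected values in these diagonal and `UniformScaling` cases easy to verify by hand:

ξ1, Λ1 = [-1.0, -1.0], [2.0 0.0; 0.0 2.0]
ξ2, Λ2 = [1.0, 1.0], [2.0 0.0; 0.0 4.0]

ξ, Λ = ξ1 + ξ2, Λ1 + Λ2      # product of Gaussians: add weighted means and precisions

@assert ξ == [0.0, 0.0]
@assert Λ == [4.0 0.0; 0.0 6.0]   # the expected [4, 6] diagonal above
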
NormalWeightedMeanPrecision(2.0, I) == NormalWeightedMeanPrecision(2.0, 1.0) ++ @test NormalWeightedMeanPrecision(2.0, 6 * I) == NormalWeightedMeanPrecision(2.0, 6.0) ++ @test NormalWeightedMeanPrecision(2, 6.0 * I) == NormalWeightedMeanPrecision(2.0, 6.0) ++ + @test eltype(NormalWeightedMeanPrecision()) === Float64 + @test eltype(NormalWeightedMeanPrecision(0.0)) === Float64 + @test eltype(NormalWeightedMeanPrecision(0.0, 1.0)) === Float64 +diff --git a/test/distributions/test_pointmass.jl b/test/distributions/test_pointmass.jl +index 8f45a481..1a2e926b 100644 +--- a/test/distributions/test_pointmass.jl ++++ b/test/distributions/test_pointmass.jl +@@ -5,6 +5,7 @@ using ReactiveMP + using Distributions + using Random + using SpecialFunctions ++using LinearAlgebra: UniformScaling, I + + import ReactiveMP: CountingReal, tiny, huge + import ReactiveMP.MacroHelpers: @test_inferred +@@ -163,6 +164,47 @@ import ReactiveMP: xtlog, mirrorlog + @test @test_inferred(AbstractMatrix{T}, mean(loggamma, dist)) == loggamma.(matrix) + end + end ++ ++ @testset "UniformScaling-based PointMass" begin ++ for T in (Float16, Float32, Float64, BigFloat) ++ matrix = convert(T, 5) * I ++ dist = PointMass(matrix) ++ ++ @test variate_form(dist) === Matrixvariate ++ @test dist[2, 1] == zero(T) ++ @test dist[3, 1] == zero(T) ++ @test dist[3, 3] === matrix[3, 3] ++ ++ @test pdf(dist, matrix) == one(T) ++ @test pdf(dist, matrix + convert(T, tiny) * I) == zero(T) ++ @test pdf(dist, matrix - convert(T, tiny) * I) == zero(T) ++ ++ @test logpdf(dist, matrix) == zero(T) ++ @test logpdf(dist, matrix + convert(T, tiny) * I) == convert(T, -Inf) ++ @test logpdf(dist, matrix - convert(T, tiny) * I) == convert(T, -Inf) ++ ++ @test_throws MethodError insupport(dist, one(T)) ++ @test_throws MethodError insupport(dist, ones(T, 2)) ++ @test_throws MethodError pdf(dist, one(T)) ++ @test_throws MethodError pdf(dist, ones(T, 2)) ++ @test_throws MethodError logpdf(dist, one(T)) ++ @test_throws MethodError logpdf(dist, ones(T, 2)) ++ ++ @test (@inferred entropy(dist)) == CountingReal(eltype(dist), -1) ++ ++ @test mean(dist) == matrix ++ @test mode(dist) == matrix ++ @test var(dist) == zero(T) * I ++ @test std(dist) == zero(T) * I ++ ++ @test_throws ErrorException cov(dist) ++ @test_throws ErrorException precision(dist) ++ ++ @test_throws ErrorException probvec(dist) ++ @test mean(inv, dist) ≈ inv(matrix) ++ @test mean(cholinv, dist) ≈ inv(matrix) ++ end ++ end + end + + end +diff --git a/test/rules/bernoulli/test_marginals.jl b/test/rules/bernoulli/test_marginals.jl +index da6be1e2..d1b835f8 100644 +--- a/test/rules/bernoulli/test_marginals.jl ++++ b/test/rules/bernoulli/test_marginals.jl +@@ -13,5 +13,12 @@ import ReactiveMP: @test_marginalrules + (input = (m_out = PointMass(0.0), m_p = Beta(1.0, 2.0)), output = (out = PointMass(0.0), p = Beta(1.0, 3.0))) + ] + end ++ @testset "out_p: (m_out::Bernoulli, m_p::PointMass)" begin ++ @test_marginalrules [with_float_conversions = true] Bernoulli(:out_p) [ ++ (input = (m_out = Bernoulli(0.8), m_p = PointMass(1.0)), output = (out = Bernoulli(1.0), p = PointMass(1.0))), ++ (input = (m_out = Bernoulli(0.2), m_p = PointMass(1.0)), output = (out = Bernoulli(1.0), p = PointMass(1.0))), ++ (input = (m_out = Bernoulli(0.2), m_p = PointMass(0.0)), output = (out = Bernoulli(0.0), p = PointMass(0.0))) ++ ] ++ end + end + end +diff --git a/test/rules/bernoulli/test_p.jl b/test/rules/bernoulli/test_p.jl +index c6c34c34..f5b08bf7 100644 +--- a/test/rules/bernoulli/test_p.jl ++++ 
b/test/rules/bernoulli/test_p.jl +@@ -19,7 +19,7 @@ import ReactiveMP: @test_rules + end + + @testset "Variational Message Passing: (q_out::DiscreteNonParametric)" begin +- # `with_falot_conversions = false` here is because apparently ++ # `with_float_conversions = false` here is because apparently + # BigFloat(0.7) + BigFloat(0.3) != BigFloat(1.0) + @test_rules [with_float_conversions = false] Bernoulli(:p, Marginalisation) [ + (input = (q_out = Categorical([0.0, 1.0]),), output = Beta(2.0, 1.0)), (input = (q_out = Categorical([0.7, 0.3]),), output = Beta(13 / 10, 17 / 10)) +diff --git a/test/rules/categorical/test_marginals.jl b/test/rules/categorical/test_marginals.jl +new file mode 100644 +index 00000000..43363b5c +--- /dev/null ++++ b/test/rules/categorical/test_marginals.jl +@@ -0,0 +1,25 @@ ++module RulesCategoricalMarginalsTest ++ ++using Test ++using ReactiveMP ++using Random ++using LinearAlgebra ++import ReactiveMP: @test_marginalrules ++ ++@testset "marginalrules:Categorical" begin ++ @testset "out_p: (m_out::PointMass, m_p::Dirichlet)" begin ++ @test_marginalrules [with_float_conversions = true] Categorical(:out_p) [ ++ (input = (m_out = PointMass([0.0, 1.0]), m_p = Dirichlet([2.0, 1.0])), output = (out = PointMass([0.0, 1.0]), p = Dirichlet([2.0, 2.0]))), ++ (input = (m_out = PointMass([0.0, 1.0]), m_p = Dirichlet([4.0, 2.0])), output = (out = PointMass([0.0, 1.0]), p = Dirichlet([4.0, 3.0]))), ++ (input = (m_out = PointMass([1.0, 0.0]), m_p = Dirichlet([1.0, 2.0])), output = (out = PointMass([1.0, 0.0]), p = Dirichlet([2.0, 2.0]))) ++ ] ++ end ++ @testset "out_p: (m_out::Categorical, m_p::PointMass)" begin ++ @test_marginalrules [with_float_conversions = false] Categorical(:out_p) [ ++ (input = (m_out = Categorical([0.2, 0.8]), m_p = PointMass([0.0, 1.0])), output = (out = Categorical(normalize([tiny, 0.8], 1)), p = PointMass([0.0, 1.0]))), ++ (input = (m_out = Categorical([0.8, 0.2]), m_p = PointMass([0.0, 1.0])), output = (out = Categorical(normalize([tiny, 0.2], 1)), p = PointMass([0.0, 1.0]))), ++ (input = (m_out = Categorical([0.8, 0.2]), m_p = PointMass([1.0, 0.0])), output = (out = Categorical(normalize([0.8, tiny], 1)), p = PointMass([1.0, 0.0]))) ++ ] ++ end ++end ++end +diff --git a/test/rules/categorical/test_out.jl b/test/rules/categorical/test_out.jl +new file mode 100644 +index 00000000..929ffbc2 +--- /dev/null ++++ b/test/rules/categorical/test_out.jl +@@ -0,0 +1,27 @@ ++module RulesCategoricalOutTest ++ ++using Test ++using ReactiveMP ++using Random ++import ReactiveMP: @test_rules ++ ++@testset "rules:Categorical:out" begin ++ @testset "Belief Propagation: (m_p::PointMass)" begin ++ @test_rules [with_float_conversions = false] Categorical(:out, Marginalisation) [ ++ (input = (m_p = PointMass([0.0, 1.0]),), output = Categorical([0.0, 1.0])), (input = (m_p = PointMass([0.8, 0.2]),), output = Categorical([0.8, 0.2])) ++ ] ++ end ++ ++ @testset "Variational Message Passing: (q_p::PointMass)" begin ++ @test_rules [with_float_conversions = false] Categorical(:out, Marginalisation) [ ++ (input = (q_p = PointMass([0.0, 1.0]),), output = Categorical([0.0, 1.0])), (input = (q_p = PointMass([0.7, 0.3]),), output = Categorical([0.7, 0.3])) ++ ] ++ end ++ ++ @testset "Variational Message Passing: (q_p::Dirichlet)" begin ++ @test_rules [with_float_conversions = false] Categorical(:out, Marginalisation) [ ++ (input = (q_p = Dirichlet([1.0, 1.0]),), output = Categorical([0.5, 0.5])), (input = (q_p = Dirichlet([0.2, 0.2]),), output = Categorical([0.5, 0.5])) ++ ] ++ end 
++end ++end +diff --git a/test/rules/categorical/test_p.jl b/test/rules/categorical/test_p.jl +new file mode 100644 +index 00000000..55e396ca +--- /dev/null ++++ b/test/rules/categorical/test_p.jl +@@ -0,0 +1,21 @@ ++module RulesCategoricalPTest ++ ++using Test ++using ReactiveMP ++using Random ++import ReactiveMP: @test_rules ++ ++@testset "rules:Categorical:p" begin ++ @testset "Variational Message Passing: (q_out::PointMass)" begin ++ @test_rules [with_float_conversions = true] Categorical(:p, Marginalisation) [ ++ (input = (q_out = PointMass([0.0, 1.0]),), output = Dirichlet([1.0, 2.0])), (input = (q_out = PointMass([0.8, 0.2]),), output = Dirichlet([9 / 5, 12 / 10])) ++ ] ++ end ++ ++ @testset "Variational Message Passing: (q_out::Categorical)" begin ++ @test_rules [with_float_conversions = false] Categorical(:p, Marginalisation) [ ++ (input = (q_out = Categorical([0.0, 1.0]),), output = Dirichlet([1.0, 2.0])), (input = (q_out = Categorical([0.7, 0.3]),), output = Dirichlet([17 / 10, 13 / 10])) ++ ] ++ end ++end ++end +diff --git a/test/rules/dirichlet/test_marginals.jl b/test/rules/dirichlet/test_marginals.jl +new file mode 100644 +index 00000000..51682227 +--- /dev/null ++++ b/test/rules/dirichlet/test_marginals.jl +@@ -0,0 +1,17 @@ ++module RulesDirichletMarginalsTest ++ ++using Test ++using ReactiveMP ++using Random ++import ReactiveMP: @test_marginalrules ++ ++@testset "marginalrules:Dirichlet" begin ++ @testset "out_a: (m_out::Dirichlet, m_a::PointMass)" begin ++ @test_marginalrules [with_float_conversions = true] Dirichlet(:out_a) [ ++ (input = (m_out = Dirichlet([1.0, 2.0]), m_a = PointMass([0.2, 1.0])), output = (out = Dirichlet([0.2, 2.0]), a = PointMass([0.2, 1.0]))), ++ (input = (m_out = Dirichlet([2.0, 2.0]), m_a = PointMass([2.0, 0.5])), output = (out = Dirichlet([3.0, 1.5]), a = PointMass([2.0, 0.5]))), ++ (input = (m_out = Dirichlet([2.0, 3.0]), m_a = PointMass([3.0, 1.0])), output = (out = Dirichlet([4.0, 3.0]), a = PointMass([3.0, 1.0]))) ++ ] ++ end ++end ++end +diff --git a/test/rules/dirichlet/test_out.jl b/test/rules/dirichlet/test_out.jl +new file mode 100644 +index 00000000..874a5885 +--- /dev/null ++++ b/test/rules/dirichlet/test_out.jl +@@ -0,0 +1,25 @@ ++module RulesDirichletOutTest ++ ++using Test ++using ReactiveMP ++using Random ++import ReactiveMP: @test_rules ++ ++@testset "rules:Dirichlet:out" begin ++ @testset "Belief Propagation: (m_a::PointMass)" begin ++ @test_rules [with_float_conversions = true] Dirichlet(:out, Marginalisation) [ ++ (input = (m_a = PointMass([0.2, 1.0]),), output = Dirichlet([0.2, 1.0])), ++ (input = (m_a = PointMass([2.0, 0.5]),), output = Dirichlet([2.0, 0.5])), ++ (input = (m_a = PointMass([3.0, 1.0]),), output = Dirichlet([3.0, 1.0])) ++ ] ++ end ++ ++ @testset "Variational Message Passing: (q_a::PointMass)" begin ++ @test_rules [with_float_conversions = true] Dirichlet(:out, Marginalisation) [ ++ (input = (q_a = PointMass([0.2, 1.0]),), output = Dirichlet([0.2, 1.0])), ++ (input = (q_a = PointMass([2.0, 0.5]),), output = Dirichlet([2.0, 0.5])), ++ (input = (q_a = PointMass([3.0, 1.0]),), output = Dirichlet([3.0, 1.0])) ++ ] ++ end ++end ++end +diff --git a/test/runtests.jl b/test/runtests.jl +index 2e50971e..0c2fa1e2 100644 +--- a/test/runtests.jl ++++ b/test/runtests.jl +@@ -312,6 +312,10 @@ end + addtests(testrunner, "rules/beta/test_out.jl") + addtests(testrunner, "rules/beta/test_marginals.jl") + ++ addtests(testrunner, "rules/categorical/test_out.jl") ++ addtests(testrunner, "rules/categorical/test_p.jl") ++ 
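
Note: the `Categorical(:p)` and Dirichlet expectations above follow the standard conjugate update: the outgoing Dirichlet adds the expected one-hot statistics of the output to a flat prior of ones. A numeric check of the `PointMass([0.8, 0.2])` case:

q_out = [0.8, 0.2]            # expected one-hot statistics of the output
α = 1 .+ q_out                # flat Dirichlet(1, 1) prior plus expected counts

@assert α ≈ [9 / 5, 12 / 10]  # matches Dirichlet([9 / 5, 12 / 10]) above
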
addtests(testrunner, "rules/categorical/test_marginals.jl") ++ + addtests(testrunner, "rules/delta/unscented/test_out.jl") + addtests(testrunner, "rules/delta/unscented/test_in.jl") + addtests(testrunner, "rules/delta/unscented/test_marginals.jl") +@@ -324,6 +328,9 @@ end + addtests(testrunner, "rules/delta/cvi/test_marginals.jl") + addtests(testrunner, "rules/delta/cvi/test_out.jl") + ++ addtests(testrunner, "rules/dirichlet/test_marginals.jl") ++ addtests(testrunner, "rules/dirichlet/test_out.jl") ++ + addtests(testrunner, "rules/dot_product/test_out.jl") + addtests(testrunner, "rules/dot_product/test_in1.jl") + addtests(testrunner, "rules/dot_product/test_in2.jl") +diff --git a/test/variables/test_constant.jl b/test/variables/test_constant.jl +index 8453b723..0e349414 100644 +--- a/test/variables/test_constant.jl ++++ b/test/variables/test_constant.jl +@@ -4,13 +4,15 @@ using Test + using ReactiveMP + using Rocket + ++using LinearAlgebra: I ++ + import ReactiveMP: collection_type, VariableIndividual, VariableVector, VariableArray, linear_index + import ReactiveMP: getconst, proxy_variables + import ReactiveMP: israndom, isproxy + + @testset "ConstVariable" begin + @testset "Simple creation" begin +- for sym in (:x, :y, :z), value in (1.0, 1.0, "asd", [1.0, 1.0], [1.0 0.0; 0.0 1.0], (x) -> 1) ++ for sym in (:x, :y, :z), value in (1.0, 1.0, "asd", I, 0.3 * I, [1.0, 1.0], [1.0 0.0; 0.0 1.0], (x) -> 1) + v = constvar(sym, value) + + @test !israndom(v) +diff --git a/test/variables/test_data.jl b/test/variables/test_data.jl +index 683cf1a1..42b517a5 100644 +--- a/test/variables/test_data.jl ++++ b/test/variables/test_data.jl +@@ -4,10 +4,10 @@ using Test + using ReactiveMP + using Rocket + +-import ReactiveMP: DataVariableCreationOptions ++import ReactiveMP: DataVariableCreationOptions, MessageObservable + import ReactiveMP: collection_type, VariableIndividual, VariableVector, VariableArray, linear_index + import ReactiveMP: getconst, proxy_variables +-import ReactiveMP: israndom, isproxy, allows_missings ++import ReactiveMP: israndom, isproxy, isused, isconnected, setmessagein!, allows_missings + + @testset "DataVariable" begin + @testset "Simple creation" begin +@@ -44,10 +44,20 @@ import ReactiveMP: israndom, isproxy, allows_missings + @test !israndom(variable) + @test eltype(variable) === T + @test name(variable) === sym ++ @test allows_missings(variable) === allow_missings + @test collection_type(variable) isa VariableIndividual + @test proxy_variables(variable) === nothing + @test !isproxy(variable) +- @test allows_missings(variable) === allow_missings ++ @test !isused(variable) ++ @test !isconnected(variable) ++ ++ setmessagein!(variable, 1, MessageObservable()) ++ ++ @test isused(variable) ++ @test isconnected(variable) ++ ++ # `100` could a valid index, but messages should be initialized in order, previous was `1` ++ @test_throws ErrorException setmessagein!(variable, 100, MessageObservable()) + end + + for sym in (:x, :y, :z), T in (Float64, Int64, Vector{Float64}), n in (10, 20), allow_missings in (true, false) +@@ -59,22 +69,24 @@ import ReactiveMP: israndom, isproxy, allows_missings + @test variables isa Vector + @test all(v -> !israndom(v), variables) + @test all(v -> name(v) === sym, variables) ++ @test all(v -> allows_missings(v) === allow_missings, variables) + @test all(v -> collection_type(v) isa VariableVector, variables) + @test all(t -> linear_index(collection_type(t[2])) === t[1], enumerate(variables)) + @test all(v -> eltype(v) === T, variables) + @test 
!isproxy(variables) + @test all(v -> !isproxy(v), variables) ++ @test all(v -> !isused(v), variables) ++ @test all(v -> !isconnected(v), variables) + @test test_updates(variables, T, (n,)) + +- @test all(v -> allows_missings(v) === allow_missings, variables) +- if allow_missings +- test_updates(variables, Missing, (n,)) +- end ++ foreach(v -> setmessagein!(v, 1, MessageObservable()), variables) ++ ++ @test all(v -> isused(v), variables) ++ @test all(v -> isconnected(v), variables) + end + +- for sym in (:x, :y, :z), T in (Float64, Int64, Vector{Float64}), l in (10, 20), r in (10, 20), allow_missings in (true, false) +- options = DataVariableCreationOptions(T, nothing, Val(allow_missings)) +- for variables in (datavar(options, sym, T, l, r), datavar(options, sym, T, (l, r))) ++ for sym in (:x, :y, :z), T in (Float64, Int64, Vector{Float64}), l in (10, 20), r in (10, 20) ++ for variables in (datavar(sym, T, l, r), datavar(sym, T, (l, r))) + @test !israndom(variables) + @test size(variables) === (l, r) + @test length(variables) === l * r +@@ -86,12 +98,13 @@ import ReactiveMP: israndom, isproxy, allows_missings + @test all(v -> eltype(v) === T, variables) + @test !isproxy(variables) + @test all(v -> !isproxy(v), variables) ++ @test all(v -> !isused(v), variables) + @test test_updates(variables, T, (l, r)) + +- @test all(v -> allows_missings(v) === allow_missings, variables) +- if allow_missings +- test_updates(variables, Missing, (l, r)) +- end ++ foreach(v -> setmessagein!(v, 1, MessageObservable()), variables) ++ ++ @test all(v -> isused(v), variables) ++ @test all(v -> isconnected(v), variables) + end + end + end diff --git a/src/nodes/mv_autoregressive.jl b/src/nodes/mv_autoregressive.jl index 1cc739b52..3ce4bd99f 100644 --- a/src/nodes/mv_autoregressive.jl +++ b/src/nodes/mv_autoregressive.jl @@ -10,17 +10,22 @@ const MvAutoregressive = MAR struct MARMeta order :: Int # order (lag) of MAR ds :: Int # dimensionality of MAR process, i.e., the number of correlated AR processes + Fs :: Vector{<:AbstractMatrix} # masks + es :: Vector{<:AbstractVector} # unit vectors function MARMeta(order, ds = 2) - if ds < 2 - @error "ds parameter should be > 1. Use AR node if ds = 1" - end - return new(order, ds) + @assert ds >= 2 "ds parameter should be > 1. 
Use AR node if ds = 1" + Fs = [mask_mar(order, ds, i) for i in 1:ds] + es = [uvector(ds, i) for i in 1:ds] + return new(order, ds, Fs, es) end end getorder(meta::MARMeta) = meta.order getdimensionality(meta::MARMeta) = meta.ds +getmasks(meta::MARMeta) = meta.Fs +getunits(meta::MARMeta) = meta.es + @node MAR Stochastic [y, x, a, Λ] diff --git a/src/rules/mv_autoregressive/a.jl b/src/rules/mv_autoregressive/a.jl index 0777efb7b..758486d2e 100644 --- a/src/rules/mv_autoregressive/a.jl +++ b/src/rules/mv_autoregressive/a.jl @@ -1,11 +1,12 @@ @rule MAR(:a, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin order, ds = getorder(meta), getdimensionality(meta) - F = Multivariate - - dim = order * ds + Fs, es = getmasks(meta), getunits(meta) + dim = order * ds m, V = mean_cov(q_y_x) + + F = Multivariate my, Vy = ar_slice(F, m, 1:dim), ar_slice(F, V, 1:dim, 1:dim) mx, Vx = ar_slice(F, m, (dim + 1):(2dim)), ar_slice(F, V, (dim + 1):(2dim), (dim + 1):(2dim)) diff --git a/src/rules/mv_autoregressive/lambda.jl b/src/rules/mv_autoregressive/lambda.jl index 29f88cbae..bb917c384 100644 --- a/src/rules/mv_autoregressive/lambda.jl +++ b/src/rules/mv_autoregressive/lambda.jl @@ -1,7 +1,9 @@ @rule MAR(:Λ, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, meta::MARMeta) = begin order, ds = getorder(meta), getdimensionality(meta) + Fs, es = getmasks(meta), getunits(meta) + dim = order * ds + F = Multivariate - dim = order * ds ma, Va = mean_cov(q_a) @@ -12,8 +14,6 @@ mx, Vx = ar_slice(F, m, (dim + 1):(2dim)), ar_slice(F, V, (dim + 1):(2dim), (dim + 1):(2dim)) Vyx = ar_slice(F, V, 1:dim, (dim + 1):(2dim)) - es = [uvector(dim, i) for i in 1:ds] - Fs = [mask_mar(order, ds, i) for i in 1:ds] S = mar_shift(order, ds) G₁ = (my * my' + Vy)[1:ds, 1:ds] G₂ = ((my * mx' + Vyx) * mA')[1:ds, 1:ds] diff --git a/src/rules/mv_autoregressive/marginals.jl b/src/rules/mv_autoregressive/marginals.jl index 92a71a659..2d789a442 100644 --- a/src/rules/mv_autoregressive/marginals.jl +++ b/src/rules/mv_autoregressive/marginals.jl @@ -9,8 +9,7 @@ function ar_y_x_marginal( m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta ) order, ds = getorder(meta), getdimensionality(meta) - F = Multivariate - dim = order * ds + Fs, es = getmasks(meta), getunits(meta) ma, Va = mean_cov(q_a) mΛ = mean(q_Λ) @@ -24,10 +23,6 @@ function ar_y_x_marginal( inv_b_Vy = cholinv(b_Vy) inv_f_Vx = cholinv(f_Vx) - # this should be inside MARMeta - es = [uvector(dim, i) for i in 1:ds] - Fs = [mask_mar(order, ds, i) for i in 1:ds] - Ξ = inv_f_Vx + sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds) W_11 = inv_b_Vy + mW diff --git a/src/rules/mv_autoregressive/x.jl b/src/rules/mv_autoregressive/x.jl index e191589ad..59f0ee29b 100644 --- a/src/rules/mv_autoregressive/x.jl +++ b/src/rules/mv_autoregressive/x.jl @@ -6,13 +6,11 @@ mΛ = mean(q_Λ) order, ds = getorder(meta), getdimensionality(meta) - dim = order * ds + Fs, es = getmasks(meta), getunits(meta) + dim = order * ds mA = mar_companion_matrix(order, ds, ma) mW = mar_transition(getorder(meta), mΛ) - # this should be inside MARMeta - es = [uvector(dim, i) for i in 1:ds] - Fs = [mask_mar(order, ds, i) for i in 1:ds] Λ = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds) @@ -31,15 +29,12 @@ end mΛ = mean(q_Λ) order, ds = getorder(meta), 
getdimensionality(meta) - dim = order * ds + Fs, es = getmasks(meta), getunits(meta) + dim = order * ds mA = mar_companion_matrix(order, ds, ma) mW = mar_transition(getorder(meta), mΛ) - # this should be inside MARMeta - es = [uvector(dim, i) for i in 1:ds] - Fs = [mask_mar(order, ds, i) for i in 1:ds] - Λ = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds) Λ₀ = Hermitian(mA' * mW * mA) diff --git a/src/rules/mv_autoregressive/y.jl b/src/rules/mv_autoregressive/y.jl index b99ace9b3..ba1122b80 100644 --- a/src/rules/mv_autoregressive/y.jl +++ b/src/rules/mv_autoregressive/y.jl @@ -5,13 +5,12 @@ mΛ = mean(q_Λ) order, ds = getorder(meta), getdimensionality(meta) + Fs, es = getmasks(meta), getunits(meta) + dim = order * ds mA = mar_companion_matrix(order, ds, ma) mW = mar_transition(getorder(meta), mΛ) - dim = order * ds - # this should be inside MARMeta - es = [uvector(dim, i) for i in 1:ds] - Fs = [mask_mar(order, ds, i) for i in 1:ds] + Λ = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds) From dc836fe168e34a3e4328ab7617d3d145513a8c22 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Sun, 19 Mar 2023 10:01:31 +0100 Subject: [PATCH 39/48] Fix MAR rules --- src/nodes/mv_autoregressive.jl | 3 +-- src/rules/mv_autoregressive/a.jl | 2 +- src/rules/mv_autoregressive/x.jl | 1 - src/rules/mv_autoregressive/y.jl | 2 -- 4 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/nodes/mv_autoregressive.jl b/src/nodes/mv_autoregressive.jl index 3ce4bd99f..98e890dfd 100644 --- a/src/nodes/mv_autoregressive.jl +++ b/src/nodes/mv_autoregressive.jl @@ -16,7 +16,7 @@ struct MARMeta function MARMeta(order, ds = 2) @assert ds >= 2 "ds parameter should be > 1. Use AR node if ds = 1" Fs = [mask_mar(order, ds, i) for i in 1:ds] - es = [uvector(ds, i) for i in 1:ds] + es = [uvector(order * ds, i) for i in 1:ds] return new(order, ds, Fs, es) end end @@ -26,7 +26,6 @@ getdimensionality(meta::MARMeta) = meta.ds getmasks(meta::MARMeta) = meta.Fs getunits(meta::MARMeta) = meta.es - @node MAR Stochastic [y, x, a, Λ] default_meta(::Type{MAR}) = error("MvAutoregressive node requires meta flag explicitly specified") diff --git a/src/rules/mv_autoregressive/a.jl b/src/rules/mv_autoregressive/a.jl index 758486d2e..9b09643b3 100644 --- a/src/rules/mv_autoregressive/a.jl +++ b/src/rules/mv_autoregressive/a.jl @@ -5,7 +5,7 @@ dim = order * ds m, V = mean_cov(q_y_x) - + F = Multivariate my, Vy = ar_slice(F, m, 1:dim), ar_slice(F, V, 1:dim, 1:dim) diff --git a/src/rules/mv_autoregressive/x.jl b/src/rules/mv_autoregressive/x.jl index 59f0ee29b..d5aa1f9c9 100644 --- a/src/rules/mv_autoregressive/x.jl +++ b/src/rules/mv_autoregressive/x.jl @@ -30,7 +30,6 @@ end order, ds = getorder(meta), getdimensionality(meta) Fs, es = getmasks(meta), getunits(meta) - dim = order * ds mA = mar_companion_matrix(order, ds, ma) mW = mar_transition(getorder(meta), mΛ) diff --git a/src/rules/mv_autoregressive/y.jl b/src/rules/mv_autoregressive/y.jl index ba1122b80..56ad21c5e 100644 --- a/src/rules/mv_autoregressive/y.jl +++ b/src/rules/mv_autoregressive/y.jl @@ -6,12 +6,10 @@ order, ds = getorder(meta), getdimensionality(meta) Fs, es = getmasks(meta), getunits(meta) - dim = order * ds mA = mar_companion_matrix(order, ds, ma) mW = mar_transition(getorder(meta), mΛ) - Λ = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds) Ξ = Λ + Wx From a3fcc3743027320a85a7dcedce8dbe178e37c123 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Sun, 19 Mar 2023 
19:18:06 +0100 Subject: [PATCH 40/48] Decrease allocs --- src/nodes/mv_autoregressive.jl | 38 +++++++++++------------- src/rules/mv_autoregressive/a.jl | 5 ---- src/rules/mv_autoregressive/lambda.jl | 8 ++--- src/rules/mv_autoregressive/marginals.jl | 2 +- src/rules/mv_autoregressive/x.jl | 4 +-- src/rules/mv_autoregressive/y.jl | 4 +-- 6 files changed, 24 insertions(+), 37 deletions(-) diff --git a/src/nodes/mv_autoregressive.jl b/src/nodes/mv_autoregressive.jl index 98e890dfd..3eb47451a 100644 --- a/src/nodes/mv_autoregressive.jl +++ b/src/nodes/mv_autoregressive.jl @@ -36,24 +36,20 @@ default_meta(::Type{MAR}) = error("MvAutoregressive node requires meta flag expl mΛ = mean(q_Λ) order, ds = getorder(meta), getdimensionality(meta) - F = Multivariate + Fs, es = getmasks(meta), getunits(meta) + dim = order * ds + F = Multivariate + n = div(ndims(q_y_x), 2) ma, Va = mean_cov(q_a) - mA = mar_companion_matrix(order, ds, ma)[1:ds, 1:dim] + mA = mar_companion_matrix(ma, meta)[1:ds, 1:dim] mx, Vx = ar_slice(F, myx, (dim + 1):(2dim)), ar_slice(F, Vyx, (dim + 1):(2dim), (dim + 1):(2dim)) my1, Vy1 = myx[1:ds], Vyx[1:ds, 1:ds] Vy1x = ar_slice(F, Vyx, 1:ds, (dim + 1):(2dim)) - # @show Vyx - # @show Vy1x - - # this should be inside MARMeta - es = [uvector(ds, i) for i in 1:ds] - Fs = [mask_mar(order, ds, i) for i in 1:ds] - g₁ = my1' * mΛ * my1 + tr(Vy1 * mΛ) g₂ = mx' * mA' * mΛ * my1 + tr(Vy1x * mA' * mΛ) g₃ = g₂ @@ -82,12 +78,13 @@ end mΛ = mean(q_Λ) order, ds = getorder(meta), getdimensionality(meta) - F = Multivariate + Fs, es = getmasks(meta), getunits(meta) + dim = order * ds - n = dim + F = Multivariate ma, Va = mean_cov(q_a) - mA = mar_companion_matrix(order, ds, ma)[1:ds, 1:dim] + mA = mar_companion_matrix(ma, meta)[1:ds, 1:dim] my1, Vy1 = my[1:ds], Vy[1:ds, 1:ds] @@ -100,7 +97,7 @@ end g₃ = -g₂ G = sum(sum(es[i]' * mΛ * es[j] * Fs[i] * (ma * ma' + Va) * Fs[j]' for i in 1:ds) for j in 1:ds) g₄ = mx' * G * mx + tr(Vx * G) - AE = n / 2 * log2π - 0.5 * mean(logdet, q_Λ) + 0.5 * (g₁ + g₂ + g₃ + g₄) + AE = dim / 2 * log2π - 0.5 * mean(logdet, q_Λ) + 0.5 * (g₁ + g₂ + g₃ + g₄) if order > 1 AE += entropy(q_y) @@ -137,10 +134,8 @@ end function mar_shift(order, ds) dim = order * ds S = diageye(dim) - for i in dim:-1:(ds + 1) - S[i, :] = S[i - ds, :] - end - S[1:ds, :] = zeros(ds, dim) + S = circshift(S, ds) + S[:, (end - ds + 1):end] = zeros(dim, ds) return S end @@ -150,11 +145,12 @@ function uvector(dim, pos = 1) return dim == 1 ? 
u[pos] : u end -function mar_companion_matrix(order, ds, a) - dim = order * ds +function mar_companion_matrix(a, meta::MARMeta) + order, ds = getorder(meta), getdimensionality(meta) + Fs, es = getmasks(meta), getunits(meta) + dim = order * ds + S = mar_shift(order, ds) - es = [uvector(dim, i) for i in 1:ds] - Fs = [mask_mar(order, ds, i) for i in 1:ds] L = S .+ sum(es[i] * a' * Fs[i]' for i in 1:ds) return L end diff --git a/src/rules/mv_autoregressive/a.jl b/src/rules/mv_autoregressive/a.jl index 9b09643b3..7fe729e9a 100644 --- a/src/rules/mv_autoregressive/a.jl +++ b/src/rules/mv_autoregressive/a.jl @@ -15,11 +15,6 @@ mΛ = mean(q_Λ) mW = mar_transition(order, mΛ) - # this should be inside MARMeta - es = [uvector(dim, i) for i in 1:ds] - Fs = [mask_mar(order, ds, i) for i in 1:ds] - S = mar_shift(order, ds) - # NOTE: prove that sum(Fs[i]'*((mx*mx'+Vx')*S')*mW*es[i] for i in 1:ds) == 0.0 D = sum(sum(es[i]' * mW * es[j] * Fs[i]' * (mx * mx' + Vx) * Fs[j] for i in 1:ds) for j in 1:ds) z = sum(Fs[i]' * (mx * my' + Vyx') * mW * es[i] for i in 1:ds) diff --git a/src/rules/mv_autoregressive/lambda.jl b/src/rules/mv_autoregressive/lambda.jl index bb917c384..79abece1c 100644 --- a/src/rules/mv_autoregressive/lambda.jl +++ b/src/rules/mv_autoregressive/lambda.jl @@ -7,14 +7,13 @@ ma, Va = mean_cov(q_a) - mA = mar_companion_matrix(order, ds, ma) + mA = mar_companion_matrix(ma, meta) m, V = mean_cov(q_y_x) my, Vy = ar_slice(F, m, 1:dim), ar_slice(F, V, 1:dim, 1:dim) mx, Vx = ar_slice(F, m, (dim + 1):(2dim)), ar_slice(F, V, (dim + 1):(2dim), (dim + 1):(2dim)) Vyx = ar_slice(F, V, 1:dim, (dim + 1):(2dim)) - S = mar_shift(order, ds) G₁ = (my * my' + Vy)[1:ds, 1:ds] G₂ = ((my * mx' + Vyx) * mA')[1:ds, 1:ds] G₃ = transpose(G₂) @@ -36,11 +35,8 @@ end mx, Vx = mean_cov(q_x) ma, Va = mean_cov(q_a) - mA = mar_companion_matrix(order, ds, ma) + mA = mar_companion_matrix(ma, meta) - es = [uvector(dim, i) for i in 1:ds] - Fs = [mask_mar(order, ds, i) for i in 1:ds] - S = mar_shift(order, ds) G₁ = (my * my' + Vy)[1:ds, 1:ds] G₂ = (my * mx' * mA')[1:ds, 1:ds] G₃ = transpose(G₂) diff --git a/src/rules/mv_autoregressive/marginals.jl b/src/rules/mv_autoregressive/marginals.jl index 2d789a442..213a69061 100644 --- a/src/rules/mv_autoregressive/marginals.jl +++ b/src/rules/mv_autoregressive/marginals.jl @@ -14,7 +14,7 @@ function ar_y_x_marginal( ma, Va = mean_cov(q_a) mΛ = mean(q_Λ) - mA = mar_companion_matrix(order, ds, ma) + mA = mar_companion_matrix(ma, meta) mW = mar_transition(getorder(meta), mΛ) b_my, b_Vy = mean_cov(m_y) diff --git a/src/rules/mv_autoregressive/x.jl b/src/rules/mv_autoregressive/x.jl index d5aa1f9c9..cafd80461 100644 --- a/src/rules/mv_autoregressive/x.jl +++ b/src/rules/mv_autoregressive/x.jl @@ -9,7 +9,7 @@ Fs, es = getmasks(meta), getunits(meta) dim = order * ds - mA = mar_companion_matrix(order, ds, ma) + mA = mar_companion_matrix(ma, meta) mW = mar_transition(getorder(meta), mΛ) Λ = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds) @@ -31,7 +31,7 @@ end order, ds = getorder(meta), getdimensionality(meta) Fs, es = getmasks(meta), getunits(meta) - mA = mar_companion_matrix(order, ds, ma) + mA = mar_companion_matrix(ma, meta) mW = mar_transition(getorder(meta), mΛ) Λ = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds) diff --git a/src/rules/mv_autoregressive/y.jl b/src/rules/mv_autoregressive/y.jl index 56ad21c5e..6eede850b 100644 --- a/src/rules/mv_autoregressive/y.jl +++ b/src/rules/mv_autoregressive/y.jl @@ -7,7 +7,7 @@ order, ds = 
getorder(meta), getdimensionality(meta)
     Fs, es = getmasks(meta), getunits(meta)
 
-    mA = mar_companion_matrix(order, ds, ma)
+    mA = mar_companion_matrix(ma, meta)
     mW = mar_transition(getorder(meta), mΛ)
 
     Λ = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds)
@@ -24,7 +24,7 @@ end
 @rule MAR(:y, Marginalisation) (q_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin
     order, ds = getorder(meta), getdimensionality(meta)
 
-    mA = mar_companion_matrix(order, ds, mean(q_a))
+    mA = mar_companion_matrix(mean(q_a), meta)
     mW = mar_transition(getorder(meta), mean(q_Λ))
 
     return MvNormalMeanPrecision(mA * mean(q_x), mW)

From db06986d203291a0ecf888fa45a5d8233ca78d28 Mon Sep 17 00:00:00 2001
From: Albert Podusenko
Date: Tue, 21 Mar 2023 22:24:43 +0100
Subject: [PATCH 41/48] Optimize functions

---
 src/nodes/mv_autoregressive.jl        | 33 +++++++++++----------------
 src/rules/mv_autoregressive/a.jl      |  9 ++-------
 src/rules/mv_autoregressive/lambda.jl | 28 ++++++++++++-----------
 3 files changed, 29 insertions(+), 41 deletions(-)

diff --git a/src/nodes/mv_autoregressive.jl b/src/nodes/mv_autoregressive.jl
index 3eb47451a..f0c969362 100644
--- a/src/nodes/mv_autoregressive.jl
+++ b/src/nodes/mv_autoregressive.jl
@@ -88,10 +88,6 @@ end
 
     my1, Vy1 = my[1:ds], Vy[1:ds, 1:ds]
 
-    # this should be inside MARMeta
-    es = [uvector(ds, i) for i in 1:ds]
-    Fs = [mask_mar(order, ds, i) for i in 1:ds]
-
     g₁ = my1' * mΛ * my1 + tr(Vy1 * mΛ)
     g₂ = -mx' * mA' * mΛ * my1
     g₃ = -g₂
@@ -111,22 +107,21 @@ end
 # Helpers for AR rules
 function mask_mar(order, dimension, index)
     F = zeros(dimension * order, dimension * dimension * order)
-    rows = repeat([dimension], order)
-    cols = repeat([dimension], dimension * order)
-    FB = BlockArrays.BlockArray(F, rows, cols)
-    for k in 1:order
-        for j in 1:(dimension * order)
-            if j == index + (k - 1) * dimension
-                view(FB, BlockArrays.Block(k, j)) .= diageye(dimension)
-            end
-        end
+
+    @inbounds for k in 1:order
+        start_col = (k - 1) * dimension^2 + (index - 1) * dimension + 1
+        end_col = start_col + dimension - 1
+        start_row = (k - 1) * dimension + 1
+        end_row = start_row + dimension - 1
+        F[start_row:end_row, start_col:end_col] = I(dimension)
     end
-    return Matrix(FB)
+
+    return F
 end
 
 function mar_transition(order, Λ)
     dim = size(Λ, 1)
-    W = 1.0 * diageye(dim * order)
+    W = diageye(dim * order)
     W[1:dim, 1:dim] = Λ
     return W
 end
@@ -135,7 +130,7 @@ function mar_shift(order, ds)
     dim = order * ds
     S = diageye(dim)
     S = circshift(S, ds)
-    S[:, (end - ds + 1):end] = zeros(dim, ds)
+    S[:, (end - ds + 1):end] .= 0
     return S
 end
 
@@ -148,9 +143,7 @@ end
 function mar_companion_matrix(a, meta::MARMeta)
     order, ds = getorder(meta), getdimensionality(meta)
     Fs, es = getmasks(meta), getunits(meta)
-    dim = order * ds
 
-    S = mar_shift(order, ds)
-    L = S .+ sum(es[i] * a' * Fs[i]' for i in 1:ds)
+    L = mar_shift(order, ds) .+ sum(es[i] * a' * Fs[i]' for i in 1:ds)
     return L
-end
+end
\ No newline at end of file
diff --git a/src/rules/mv_autoregressive/a.jl b/src/rules/mv_autoregressive/a.jl
index 7fe729e9a..41b488646 100644
--- a/src/rules/mv_autoregressive/a.jl
+++ b/src/rules/mv_autoregressive/a.jl
@@ -24,9 +24,8 @@ end
 
 @rule MAR(:a, Marginalisation) (q_y::MultivariateNormalDistributionsFamily, q_x::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin
     order, ds = getorder(meta), getdimensionality(meta)
-    F = Multivariate
-
-    dim = order * ds
+    Fs, es = getmasks(meta), getunits(meta)
+    dim = order * ds
 
     my, Vy = mean_cov(q_y)
     mx, Vx = mean_cov(q_x)
@@ -35,10 +34,6 @@ end
     mW = mar_transition(order, mΛ)
     S = mar_shift(order, ds)
 
-    # this should be inside MARMeta
-    es = [uvector(dim, i) for i in 1:ds]
-    Fs = [mask_mar(order, ds, i) for i in 1:ds]
-
     D = sum(sum(es[j]' * mW * es[i] * Fs[i]' * (mx * mx' + Vx) * Fs[j] for i in 1:ds) for j in 1:ds)
     z = sum(Fs[i]' * ((mx * mx' + Vx') * S' + mx * my') * mW * es[i] for i in 1:ds)
 
diff --git a/src/rules/mv_autoregressive/lambda.jl b/src/rules/mv_autoregressive/lambda.jl
index 79abece1c..5e859f315 100644
--- a/src/rules/mv_autoregressive/lambda.jl
+++ b/src/rules/mv_autoregressive/lambda.jl
@@ -1,3 +1,13 @@
+function compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, ds, Fs, es)
+    G₁ = (my * my' + Vy)[1:ds, 1:ds]
+    G₂ = ((my * mx' + Vyx) * mA')[1:ds, 1:ds]
+    G₃ = transpose(G₂)
+    Ex_xx = mx * mx' + Vx
+    G₅ = sum(sum(es[i] * ma' * Fs[i]'Ex_xx * Fs[j] * ma * es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds]
+    G₆ = sum(sum(es[i] * tr(Fs[i]' * Ex_xx * Fs[j] * Va) * es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds]
+    Δ = G₁ - G₂ - G₃ + G₅ + G₆
+end
+
 @rule MAR(:Λ, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, meta::MARMeta) = begin
     order, ds = getorder(meta), getdimensionality(meta)
     Fs, es = getmasks(meta), getunits(meta)
@@ -14,13 +24,7 @@
     mx, Vx = ar_slice(F, m, (dim + 1):(2dim)), ar_slice(F, V, (dim + 1):(2dim), (dim + 1):(2dim))
     Vyx = ar_slice(F, V, 1:dim, (dim + 1):(2dim))
 
-    G₁ = (my * my' + Vy)[1:ds, 1:ds]
-    G₂ = ((my * mx' + Vyx) * mA')[1:ds, 1:ds]
-    G₃ = transpose(G₂)
-    Ex_xx = mx * mx' + Vx
-    G₅ = sum(sum(es[i] * ma' * Fs[i]'Ex_xx * Fs[j] * ma * es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds]
-    G₆ = sum(sum(es[i] * tr(Fs[i]' * Ex_xx * Fs[j] * Va) * es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds]
-    Δ = G₁ - G₂ - G₃ + G₅ + G₆
+    Δ = compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, ds, Fs, es)
 
     return WishartMessage(ds + 2, Δ)
 end
@@ -37,13 +41,9 @@ end
 
     mA = mar_companion_matrix(ma, meta)
 
-    G₁ = (my * my' + Vy)[1:ds, 1:ds]
-    G₂ = (my * mx' * mA')[1:ds, 1:ds]
-    G₃ = transpose(G₂)
-    Ex_xx = mx * mx' + Vx
-    G₅ = sum(sum(es[i] * ma' * Fs[j]'Ex_xx * Fs[i] * ma * es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds]
-    G₆ = sum(sum(es[i] * tr(Va * Fs[i]' * Ex_xx * Fs[j]) * es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds]
-    Δ = G₁ - G₂ - G₃ + G₅ + G₆
+    # q_y and q_x are separate marginals with no cross-covariance, so Vyx is an explicit zero matrix here
+    Vyx = zeros(eltype(my), length(my), length(mx))
+    Δ = compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, ds, Fs, es)
 
     return WishartMessage(ds + 2, Δ)
 end

From b89e419a460aaff23247865e903e2109a462ed3e Mon Sep 17 00:00:00 2001
From: Albert Podusenko
Date: Thu, 30 Mar 2023 14:48:19 +0200
Subject: [PATCH 42/48] Remove diffs

---
 branch.diff | 1963 ---------------------------------------------------
 1 file changed, 1963 deletions(-)
 delete mode 100644 branch.diff

diff --git a/branch.diff b/branch.diff
deleted file mode 100644
index b14a433b6..000000000
--- a/branch.diff
+++ /dev/null
@@ -1,1963 +0,0 @@
-diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
-index 8d5d2580..93d16876 100644
---- a/.github/workflows/ci.yml
-+++ b/.github/workflows/ci.yml
-@@ -1,9 +1,17 @@
- name: CI
- on:
-   pull_request:
-+    types: [ready_for_review,reopened,synchronize]
-+  pull_request_review:
-+    types: [submitted,edited]
-   push:
-+    branches:
-+      - 'master'
-+    tags: '*'
-+  check_run:
-+    types: [rerequested]
-   schedule:
--    - cron: '44 9 16 * *' # run the cron job one time per month
-+    - cron: '0 8 * * 1' # run the cron job one time per week on Monday 8:00 AM
- jobs:
-   format:
-     name: Julia Formatter
-@@ -17,6 +25,7 @@ jobs:
-   test:
-     name: Julia ${{ 
matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} - runs-on: ${{ matrix.os }} -+ continue-on-error: ${{ contains(matrix.version, 'nightly') }} - needs: format - strategy: - fail-fast: false -@@ -110,4 +119,4 @@ jobs: - env: - PYTHON: "" - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} -- DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} -\ No newline at end of file -+ DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} -diff --git a/Makefile b/Makefile -index 95a2851e..231f80bb 100644 ---- a/Makefile -+++ b/Makefile -@@ -30,8 +30,8 @@ docs: doc_init ## Generate documentation - - .PHONY: test - --test: ## Run tests (use testset="folder1:test1 folder2:test2" argument to run reduced testset) -- julia -e 'import Pkg; Pkg.activate("."); Pkg.test(test_args = split("$(testset)") .|> string)' -+test: ## Run tests (use test_args="folder1:test1 folder2:test2" argument to run reduced testset) -+ julia -e 'import Pkg; Pkg.activate("."); Pkg.test(test_args = split("$(test_args)") .|> string)' - - help: ## Display this help - @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[a-zA-Z_-]+:.*?##/ { printf " \033[36m%-24s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) -\ No newline at end of file -diff --git a/Project.toml b/Project.toml -index d63ca068..c4754725 100644 ---- a/Project.toml -+++ b/Project.toml -@@ -1,10 +1,9 @@ - name = "ReactiveMP" - uuid = "a194aa59-28ba-4574-a09c-4a745416d6e3" - authors = ["Dmitry Bagaev ", "Albert Podusenko ", "Bart van Erp ", "Ismail Senoz "] --version = "3.6.1" -+version = "3.7.2" - - [deps] --BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" - DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" - Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" - DomainIntegrals = "cc6bae93-f070-4015-88fd-838f9505a86c" -@@ -43,7 +42,7 @@ MacroTools = "0.5" - Optim = "1.0.0" - PositiveFactorizations = "0.2" - Requires = "1" --Rocket = "1.6.0" -+Rocket = "1.7.0" - SpecialFunctions = "1.4, 2" - StaticArrays = "1.2" - StatsBase = "0.33" -diff --git a/docs/src/extra/contributing.md b/docs/src/extra/contributing.md -index c9e0e317..9a84ca4c 100644 ---- a/docs/src/extra/contributing.md -+++ b/docs/src/extra/contributing.md -@@ -81,8 +81,8 @@ a new release of the broken dependecy is available. 
- - - `make help`: Shows help snippet - - `make test`: Run tests, supports extra arguments -- - `make test testset="distributions:normal_mean_variance"` would run tests only from `distributions/test_normal_mean_variance.jl` -- - `make test testset="distributions:normal_mean_variance models:lgssm"` would run tests both from `distributions/test_normal_mean_variance.jl` and `models/test_lgssm.jl` -+ - `make test test_args="distributions:normal_mean_variance"` would run tests only from `distributions/test_normal_mean_variance.jl` -+ - `make test test_args="distributions:normal_mean_variance models:lgssm"` would run tests both from `distributions/test_normal_mean_variance.jl` and `models/test_lgssm.jl` - - `make docs`: Compile documentation - - `make benchmark`: Run simple benchmark - - `make lint`: Check codestyle -diff --git a/src/ReactiveMP.jl b/src/ReactiveMP.jl -index 7b5c860c..fc0186b4 100644 ---- a/src/ReactiveMP.jl -+++ b/src/ReactiveMP.jl -@@ -146,7 +146,6 @@ include("nodes/matrix_dirichlet.jl") - include("nodes/dirichlet.jl") - include("nodes/bernoulli.jl") - include("nodes/gcv.jl") --include("nodes/kernel_gcv.jl") - include("nodes/wishart.jl") - include("nodes/wishart_inverse.jl") - include("nodes/normal_mixture.jl") -@@ -154,7 +153,6 @@ include("nodes/gamma_mixture.jl") - include("nodes/dot_product.jl") - include("nodes/transition.jl") - include("nodes/autoregressive.jl") --include("nodes/mv_autoregressive.jl") - include("nodes/bifm.jl") - include("nodes/bifm_helper.jl") - include("nodes/probit.jl") -diff --git a/src/constraints/form.jl b/src/constraints/form.jl -index a746cdd6..c6c2edc4 100644 ---- a/src/constraints/form.jl -+++ b/src/constraints/form.jl -@@ -172,7 +172,7 @@ function is_point_mass_form_constraint(composite::CompositeFormConstraint) - is_point_mass = map(is_point_mass_form_constraint, composite.constraints) - pmindex = findnext(is_point_mass, 1) - if pmindex !== nothing && pmindex !== length(is_point_mass) -- error("Composite form constraint supports point mass constraint only at the end of the form constrains specification.") -+ error("Composite form constraint supports point mass constraint only at the end of the form constraints specification.") - end - return last(is_point_mass) - end -diff --git a/src/distributions/bernoulli.jl b/src/distributions/bernoulli.jl -index 9757f0aa..bb30b485 100644 ---- a/src/distributions/bernoulli.jl -+++ b/src/distributions/bernoulli.jl -@@ -47,6 +47,10 @@ function prod(::ProdAnalytical, left::Bernoulli, right::Categorical) - return Categorical(ReactiveMP.normalize!(p_new, 1)) - end - -+prod_analytical_rule(::Type{<:Categorical}, ::Type{<:Bernoulli}) = ProdAnalyticalRuleAvailable() -+ -+prod(::ProdAnalytical, left::Categorical, right::Bernoulli) = prod(ProdAnalytical(), right, left) -+ - function compute_logscale(new_dist::Bernoulli, left_dist::Bernoulli, right_dist::Bernoulli) - left_p = succprob(left_dist) - right_p = succprob(right_dist) -diff --git a/src/distributions/beta.jl b/src/distributions/beta.jl -index 9541c90a..129c4f7f 100644 ---- a/src/distributions/beta.jl -+++ b/src/distributions/beta.jl -@@ -1,7 +1,9 @@ - export Beta -+export BetaNaturalParameters - - import Distributions: Beta, params --import SpecialFunctions: digamma, logbeta -+import SpecialFunctions: digamma, logbeta, loggamma -+import StatsFuns: betalogpdf - - vague(::Type{<:Beta}) = Beta(1.0, 1.0) - -@@ -27,3 +29,47 @@ function mean(::typeof(mirrorlog), dist::Beta) - a, b = params(dist) - return digamma(b) - digamma(a + b) - end -+ -+struct 
BetaNaturalParameters{T <: Real} <: NaturalParameters -+ αm1::T -+ βm1::T -+end -+ -+BetaNaturalParameters(αm1::Real, βm1::Real) = BetaNaturalParameters(promote(αm1, βm1)...) -+BetaNaturalParameters(αm1::Integer, βm1::Integer) = BetaNaturalParameters(float(αm1), float(βm1)) -+ -+Base.convert(::Type{BetaNaturalParameters}, a::Real, b::Real) = convert(BetaNaturalParameters{promote_type(typeof(a), typeof(b))}, a, b) -+ -+Base.convert(::Type{BetaNaturalParameters{T}}, a::Real, b::Real) where {T} = BetaNaturalParameters(convert(T, a), convert(T, b)) -+ -+Base.convert(::Type{BetaNaturalParameters}, vec::AbstractVector) = convert(BetaNaturalParameters{eltype(vec)}, vec) -+ -+Base.convert(::Type{BetaNaturalParameters{T}}, vec::AbstractVector) where {T} = BetaNaturalParameters(convert(AbstractVector{T}, vec)) -+ -+function isproper(params::BetaNaturalParameters) -+ return ((params.αm1 + 1) > 0) && ((params.βm1 + 1) > 0) -+end -+ -+naturalparams(dist::Beta) = BetaNaturalParameters(dist.α - 1, dist.β - 1) -+ -+function Base.convert(::Type{Distribution}, η::BetaNaturalParameters) -+ return Beta(η.αm1 + 1, η.βm1 + 1, check_args = false) -+end -+ -+function Base.vec(p::BetaNaturalParameters) -+ return [p.αm1, p.βm1] -+end -+ -+ReactiveMP.as_naturalparams(::Type{T}, args...) where {T <: BetaNaturalParameters} = convert(BetaNaturalParameters, args...) -+ -+function BetaNaturalParameters(v::AbstractVector{T}) where {T <: Real} -+ @assert length(v) === 2 "`BetaNaturalParameters` must accept a vector of length `2`." -+ return BetaNaturalParameters(v[1], v[2]) -+end -+ -+lognormalizer(params::BetaNaturalParameters) = logbeta(params.αm1 + 1, params.βm1 + 1) -+logpdf(params::BetaNaturalParameters, x) = betalogpdf(params.αm1 + 1, params.βm1 + 1, x) -+ -+function Base.:-(left::BetaNaturalParameters, right::BetaNaturalParameters) -+ return BetaNaturalParameters(left.αm1 - right.αm1, left.βm1 - right.βm1) -+end -diff --git a/src/distributions/mv_normal_mean_covariance.jl b/src/distributions/mv_normal_mean_covariance.jl -index 7938bf16..079016ca 100644 ---- a/src/distributions/mv_normal_mean_covariance.jl -+++ b/src/distributions/mv_normal_mean_covariance.jl -@@ -26,6 +26,13 @@ function MvNormalMeanCovariance(μ::AbstractVector{T}) where {T} - return MvNormalMeanCovariance(μ, convert(AbstractArray{T}, ones(length(μ)))) - end - -+function MvNormalMeanCovariance(μ::AbstractVector{T1}, Σ::UniformScaling{T2}) where {T1, T2} -+ T = promote_type(T1, T2) -+ μ_new = convert(AbstractArray{T}, μ) -+ Σ_new = convert(UniformScaling{T}, Σ)(length(μ)) -+ return MvNormalMeanCovariance(μ_new, Σ_new) -+end -+ - Distributions.distrname(::MvNormalMeanCovariance) = "MvNormalMeanCovariance" - - function weightedmean(dist::MvNormalMeanCovariance) -@@ -88,7 +95,9 @@ function Base.prod(::ProdAnalytical, left::MvNormalMeanCovariance, right::MvNorm - return MvNormalWeightedMeanPrecision(xi_left + xi_right, W_left + W_right) - end - --function Base.prod(::ProdAnalytical, left::MvNormalMeanCovariance{T1}, right::MvNormalMeanCovariance{T2}) where {T1 <: LinearAlgebra.BlasFloat, T2 <: LinearAlgebra.BlasFloat} -+function Base.prod( -+ ::ProdAnalytical, left::MvNormalMeanCovariance{T1, <:AbstractVector, <:Matrix}, right::MvNormalMeanCovariance{T2, <:AbstractVector, <:Matrix} -+) where {T1 <: LinearAlgebra.BlasFloat, T2 <: LinearAlgebra.BlasFloat} - - # start with parameters of left - xi, W = weightedmean_precision(left) -diff --git a/src/distributions/mv_normal_mean_precision.jl b/src/distributions/mv_normal_mean_precision.jl -index 
889b0014..815eb424 100644 ---- a/src/distributions/mv_normal_mean_precision.jl -+++ b/src/distributions/mv_normal_mean_precision.jl -@@ -26,6 +26,13 @@ function MvNormalMeanPrecision(μ::AbstractVector{T}) where {T} - return MvNormalMeanPrecision(μ, convert(AbstractArray{T}, ones(length(μ)))) - end - -+function MvNormalMeanPrecision(μ::AbstractVector{T1}, Λ::UniformScaling{T2}) where {T1, T2} -+ T = promote_type(T1, T2) -+ μ_new = convert(AbstractArray{T}, μ) -+ Λ_new = convert(UniformScaling{T}, Λ)(length(μ)) -+ return MvNormalMeanPrecision(μ_new, Λ_new) -+end -+ - Distributions.distrname(::MvNormalMeanPrecision) = "MvNormalMeanPrecision" - - weightedmean(dist::MvNormalMeanPrecision) = precision(dist) * mean(dist) -@@ -92,7 +99,9 @@ function Base.prod(::ProdAnalytical, left::MvNormalMeanPrecision, right::MvNorma - return MvNormalWeightedMeanPrecision(xi, W) - end - --function Base.prod(::ProdAnalytical, left::MvNormalMeanPrecision{T1}, right::MvNormalMeanPrecision{T2}) where {T1 <: LinearAlgebra.BlasFloat, T2 <: LinearAlgebra.BlasFloat} -+function Base.prod( -+ ::ProdAnalytical, left::MvNormalMeanPrecision{T1, <:AbstractVector, <:Matrix}, right::MvNormalMeanPrecision{T2, <:AbstractVector, <:Matrix} -+) where {T1 <: LinearAlgebra.BlasFloat, T2 <: LinearAlgebra.BlasFloat} - W = precision(left) + precision(right) - - # fast & efficient implementation of xi = precision(right)*mean(right) + precision(left)*mean(left) -diff --git a/src/distributions/mv_normal_weighted_mean_precision.jl b/src/distributions/mv_normal_weighted_mean_precision.jl -index 86aa046e..7137ccb4 100644 ---- a/src/distributions/mv_normal_weighted_mean_precision.jl -+++ b/src/distributions/mv_normal_weighted_mean_precision.jl -@@ -26,6 +26,13 @@ function MvNormalWeightedMeanPrecision(xi::AbstractVector{T}) where {T} - return MvNormalWeightedMeanPrecision(xi, convert(AbstractArray{T}, ones(length(xi)))) - end - -+function MvNormalWeightedMeanPrecision(xi::AbstractVector{T1}, Λ::UniformScaling{T2}) where {T1, T2} -+ T = promote_type(T1, T2) -+ xi_new = convert(AbstractArray{T}, xi) -+ Λ_new = convert(UniformScaling{T}, Λ)(length(xi)) -+ return MvNormalWeightedMeanPrecision(xi_new, Λ_new) -+end -+ - Distributions.distrname(::MvNormalWeightedMeanPrecision) = "MvNormalWeightedMeanPrecision" - - weightedmean(dist::MvNormalWeightedMeanPrecision) = dist.xi -diff --git a/src/distributions/normal_mean_precision.jl b/src/distributions/normal_mean_precision.jl -index d6b9ef8a..b5083756 100644 ---- a/src/distributions/normal_mean_precision.jl -+++ b/src/distributions/normal_mean_precision.jl -@@ -11,6 +11,12 @@ NormalMeanPrecision(μ::Real, w::Real) = NormalMeanPrecision(promote(μ, w - NormalMeanPrecision(μ::Integer, w::Integer) = NormalMeanPrecision(float(μ), float(w)) - NormalMeanPrecision(μ::Real) = NormalMeanPrecision(μ, one(μ)) - NormalMeanPrecision() = NormalMeanPrecision(0.0, 1.0) -+function NormalMeanPrecision(μ::T1, w::UniformScaling{T2}) where {T1 <: Real, T2} -+ T = promote_type(T1, T2) -+ μ_new = convert(T, μ) -+ w_new = convert(T, w.λ) -+ return NormalMeanPrecision(μ_new, w_new) -+end - - Distributions.@distr_support NormalMeanPrecision -Inf Inf - -diff --git a/src/distributions/normal_mean_variance.jl b/src/distributions/normal_mean_variance.jl -index 759f319d..2b1770b9 100644 ---- a/src/distributions/normal_mean_variance.jl -+++ b/src/distributions/normal_mean_variance.jl -@@ -11,6 +11,12 @@ NormalMeanVariance(μ::Real, v::Real) = NormalMeanVariance(promote(μ, v). 
- NormalMeanVariance(μ::Integer, v::Integer) = NormalMeanVariance(float(μ), float(v)) - NormalMeanVariance(μ::T) where {T <: Real} = NormalMeanVariance(μ, one(T)) - NormalMeanVariance() = NormalMeanVariance(0.0, 1.0) -+function NormalMeanVariance(μ::T1, v::UniformScaling{T2}) where {T1 <: Real, T2} -+ T = promote_type(T1, T2) -+ μ_new = convert(T, μ) -+ v_new = convert(T, v.λ) -+ return NormalMeanVariance(μ_new, v_new) -+end - - Distributions.@distr_support NormalMeanVariance -Inf Inf - -diff --git a/src/distributions/normal_weighted_mean_precision.jl b/src/distributions/normal_weighted_mean_precision.jl -index f68a6abe..c17f9fb3 100644 ---- a/src/distributions/normal_weighted_mean_precision.jl -+++ b/src/distributions/normal_weighted_mean_precision.jl -@@ -11,6 +11,12 @@ NormalWeightedMeanPrecision(xi::Real, w::Real) = NormalWeightedMeanPrecisi - NormalWeightedMeanPrecision(xi::Integer, w::Integer) = NormalWeightedMeanPrecision(float(xi), float(w)) - NormalWeightedMeanPrecision(xi::Real) = NormalWeightedMeanPrecision(xi, one(xi)) - NormalWeightedMeanPrecision() = NormalWeightedMeanPrecision(0.0, 1.0) -+function NormalWeightedMeanPrecision(xi::T1, w::UniformScaling{T2}) where {T1 <: Real, T2} -+ T = promote_type(T1, T2) -+ xi_new = convert(T, xi) -+ w_new = convert(T, w.λ) -+ return NormalWeightedMeanPrecision(xi_new, w_new) -+end - - Distributions.@distr_support NormalWeightedMeanPrecision -Inf Inf - -diff --git a/src/distributions/pointmass.jl b/src/distributions/pointmass.jl -index 38c15c24..c4612623 100644 ---- a/src/distributions/pointmass.jl -+++ b/src/distributions/pointmass.jl -@@ -1,5 +1,7 @@ - export PointMass, getpointmass - -+using LinearAlgebra: UniformScaling, I -+ - import Distributions: mean, var, cov, std, insupport, pdf, logpdf, entropy - import Base: ndims, precision, getindex, size, convert, isapprox, eltype - import SpecialFunctions: loggamma, logbeta -@@ -13,12 +15,14 @@ end - variate_form(::PointMass{T}) where {T <: Real} = Univariate - variate_form(::PointMass{V}) where {T, V <: AbstractVector{T}} = Multivariate - variate_form(::PointMass{M}) where {T, M <: AbstractMatrix{T}} = Matrixvariate -+variate_form(::PointMass{U}) where {T, U <: UniformScaling{T}} = Matrixvariate - - ## - - sampletype(distribution::PointMass{T}) where {T} = T - - getpointmass(distribution::PointMass) = distribution.point -+getpointmass(point::Union{Real, AbstractArray}) = point - - ## - -@@ -111,6 +115,31 @@ convert_eltype(::Type{PointMass}, ::Type{T}, distribution::PointMass{R}) where { - - Base.eltype(::PointMass{M}) where {T <: Real, M <: AbstractMatrix{T}} = T - -+# UniformScaling-based matrixvariate point mass -+ -+Distributions.insupport(distribution::PointMass{M}, x::UniformScaling) where {T <: Real, M <: UniformScaling{T}} = x == getpointmass(distribution) -+Distributions.pdf(distribution::PointMass{M}, x::UniformScaling) where {T <: Real, M <: UniformScaling{T}} = Distributions.insupport(distribution, x) ? one(T) : zero(T) -+Distributions.logpdf(distribution::PointMass{M}, x::UniformScaling) where {T <: Real, M <: UniformScaling{T}} = Distributions.insupport(distribution, x) ? 
zero(T) : convert(T, -Inf) -+ -+Distributions.mean(distribution::PointMass{M}) where {T <: Real, M <: UniformScaling{T}} = getpointmass(distribution) -+Distributions.mode(distribution::PointMass{M}) where {T <: Real, M <: UniformScaling{T}} = mean(distribution) -+Distributions.var(distribution::PointMass{M}) where {T <: Real, M <: UniformScaling{T}} = zero(T) * I -+Distributions.std(distribution::PointMass{M}) where {T <: Real, M <: UniformScaling{T}} = zero(T) * I -+Distributions.cov(distribution::PointMass{M}) where {T <: Real, M <: UniformScaling{T}} = error("Distributions.cov(::PointMass{ <: UniformScaling }) is not defined") -+ -+probvec(distribution::PointMass{M}) where {T <: Real, M <: UniformScaling{T}} = error("probvec(::PointMass{ <: UniformScaling }) is not defined") -+ -+mean(::typeof(inv), distribution::PointMass{M}) where {T <: Real, M <: UniformScaling{T}} = inv(mean(distribution)) -+mean(::typeof(cholinv), distribution::PointMass{M}) where {T <: Real, M <: UniformScaling{T}} = inv(mean(distribution)) -+ -+Base.precision(distribution::PointMass{M}) where {T <: Real, M <: UniformScaling{T}} = one(T) ./ cov(distribution) -+Base.ndims(distribution::PointMass{M}) where {T <: Real, M <: UniformScaling{T}} = size(mean(distribution)) -+ -+convert_eltype(::Type{PointMass}, ::Type{T}, distribution::PointMass{R}) where {T <: Real, R <: UniformScaling} = PointMass(convert(AbstractMatrix{T}, getpointmass(distribution))) -+convert_eltype(::Type{PointMass}, ::Type{T}, distribution::PointMass{R}) where {T <: AbstractMatrix, R <: UniformScaling} = PointMass(convert(T, getpointmass(distribution))) -+ -+Base.eltype(::PointMass{M}) where {T <: Real, M <: UniformScaling{T}} = T -+ - Base.isapprox(left::PointMass, right::PointMass; kwargs...) = Base.isapprox(getpointmass(left), getpointmass(right); kwargs...) - Base.isapprox(left::PointMass, right; kwargs...) = false - Base.isapprox(left, right::PointMass; kwargs...) 
= false -diff --git a/src/marginal.jl b/src/marginal.jl -index 72cd8819..75a39143 100644 ---- a/src/marginal.jl -+++ b/src/marginal.jl -@@ -109,6 +109,11 @@ struct SkipInitial <: MarginalSkipStrategy end - struct SkipClampedAndInitial <: MarginalSkipStrategy end - struct IncludeAll <: MarginalSkipStrategy end - -+Base.broadcastable(::SkipClamped) = Ref(SkipClamped()) -+Base.broadcastable(::SkipInitial) = Ref(SkipInitial()) -+Base.broadcastable(::SkipClampedAndInitial) = Ref(SkipClampedAndInitial()) -+Base.broadcastable(::IncludeAll) = Ref(IncludeAll()) -+ - apply_skip_filter(observable, ::SkipClamped) = observable |> filter(v -> !is_clamped(v)) - apply_skip_filter(observable, ::SkipInitial) = observable |> filter(v -> !is_initial(v)) - apply_skip_filter(observable, ::SkipClampedAndInitial) = observable |> filter(v -> !is_initial(v) && !is_clamped(v)) -diff --git a/src/node.jl b/src/node.jl -index 56b5e5e1..605d0c78 100644 ---- a/src/node.jl -+++ b/src/node.jl -@@ -866,7 +866,7 @@ function activate!(factornode::AbstractFactorNode, options) - vmessageout = combineLatest((msgs_observable, marginals_observable), PushNew()) # TODO check PushEach - vmessageout = apply_pipeline_stage(get_pipeline_stages(interface), factornode, vtag, vmessageout) - -- mapping = let messagemap = MessageMapping(fform, vtag, vconstraint, msgs_names, marginal_names, meta, addons, factornode) -+ mapping = let messagemap = MessageMapping(fform, vtag, vconstraint, msgs_names, marginal_names, meta, addons, node_if_required(fform, factornode)) - (dependencies) -> VariationalMessage(dependencies[1], dependencies[2], messagemap) - end - -@@ -939,11 +939,11 @@ function getmarginal!(factornode::FactorNode, localmarginal::FactorNodeLocalMarg - vtag = Val{name(localmarginal)} - meta = metadata(factornode) - -- mapping = MarginalMapping(fform, vtag, msgs_names, marginal_names, meta, factornode) -+ mapping = MarginalMapping(fform, vtag, msgs_names, marginal_names, meta, node_if_required(fform, factornode)) - # TODO: discontinue operator is needed for loopy belief propagation? Check - marginalout = combineLatest((msgs_observable, marginals_observable), PushNew()) |> discontinue() |> map(Marginal, mapping) - -- connect!(cmarginal, marginalout) # MarginalObservable has RecentSubject by default, there is no need to share_recent() here -+ connect!(cmarginal, marginalout) - - return apply_skip_filter(cmarginal, skip_strategy) - end -@@ -955,7 +955,7 @@ end - make_node(node) - make_node(node, options) - --Creates a factor node of a given type and options. See the list of avaialble factor nodes below. -+Creates a factor node of a given type and options. See the list of available factor nodes below. 
- - See also: [`@node`](@ref) - -diff --git a/src/nodes/autoregressive.jl b/src/nodes/autoregressive.jl -index 4e80c275..07a620fe 100644 ---- a/src/nodes/autoregressive.jl -+++ b/src/nodes/autoregressive.jl -@@ -1,4 +1,4 @@ --export AR, Autoregressive, ARsafe, ARunsafe, ARMeta, ar_unit, ar_slice -+export AR, Autoregressive, ARsafe, ARunsafe, ARMeta - - import LazyArrays - import Distributions: VariateForm -diff --git a/src/nodes/delta/delta.jl b/src/nodes/delta/delta.jl -index d302ee7c..cfe34369 100644 ---- a/src/nodes/delta/delta.jl -+++ b/src/nodes/delta/delta.jl -@@ -79,8 +79,11 @@ end - # For missing rules error msg - rule_method_error_extract_fform(f::Type{<:DeltaFn}) = "DeltaFn{f}" - -+# `DeltaFn` requires an access to the node function, hence, node reference is required -+call_rule_is_node_required(::Type{<:DeltaFn}) = CallRuleNodeRequired() -+ - # For `@call_rule` and `@call_marginalrule` --function call_rule_make_node(::UndefinedNodeFunctionalForm, fformtype::Type{<:DeltaFn}, nodetype::F, meta::DeltaMeta) where {F} -+function call_rule_make_node(::CallRuleNodeRequired, fformtype::Type{<:DeltaFn}, nodetype::F, meta::DeltaMeta) where {F} - # This node is not initialized properly, but we do not expect rules to access internal uninitialized fields. - # Doing so will most likely throw an error - return DeltaFnNode(nodetype, NodeInterface(:out, Marginalisation()), (), nothing, collect_meta(DeltaFn{F}, meta)) -diff --git a/src/nodes/kernel_gcv.jl b/src/nodes/kernel_gcv.jl -deleted file mode 100644 -index 8434eacc..00000000 ---- a/src/nodes/kernel_gcv.jl -+++ /dev/null -@@ -1,34 +0,0 @@ --export KernelGCV, KernelGCVMetadata -- --import LinearAlgebra: logdet, tr -- --struct KernelGCVMetadata{F, A} -- kernelFn :: F -- approximation :: A --end -- --get_kernelfn(meta::KernelGCVMetadata) = meta.kernelFn --get_approximation(meta::KernelGCVMetadata) = meta.approximation -- --struct KernelGCV end -- --@node KernelGCV Stochastic [y, x, z] -- --# TODO: Remove in favor of Generic Functional Message --struct FnWithApproximation{F, A} -- fn :: F -- approximation :: A --end -- --prod_analytical_rule(::Type{<:MultivariateNormalDistributionsFamily}, ::Type{<:FnWithApproximation}) = ProdAnalyticalRuleAvailable() -- --function prod(::ProdAnalytical, left::MultivariateNormalDistributionsFamily, right::FnWithApproximation) -- μ, Σ = approximate_meancov(right.approximation, (s) -> exp(right.fn(s)), left) -- return MvNormalMeanCovariance(μ, Σ) --end -- --prod_analytical_rule(::Type{<:FnWithApproximation}, ::Type{<:MultivariateNormalDistributionsFamily}) = ProdAnalyticalRuleAvailable() -- --function prod(::ProdAnalytical, left::FnWithApproximation, right::MultivariateNormalDistributionsFamily) -- return prod(ProdAnalytical(), right, left) --end -diff --git a/src/nodes/mv_autoregressive.jl b/src/nodes/mv_autoregressive.jl -deleted file mode 100644 -index 1cc739b5..00000000 ---- a/src/nodes/mv_autoregressive.jl -+++ /dev/null -@@ -1,156 +0,0 @@ --export MAR, MvAutoregressive, MARMeta, mar_transition, mar_shift -- --import LazyArrays, BlockArrays --import StatsFuns: log2π -- --struct MAR end -- --const MvAutoregressive = MAR -- --struct MARMeta -- order :: Int # order (lag) of MAR -- ds :: Int # dimensionality of MAR process, i.e., the number of correlated AR processes -- -- function MARMeta(order, ds = 2) -- if ds < 2 -- @error "ds parameter should be > 1. 
Use AR node if ds = 1" -- end -- return new(order, ds) -- end --end -- --getorder(meta::MARMeta) = meta.order --getdimensionality(meta::MARMeta) = meta.ds -- --@node MAR Stochastic [y, x, a, Λ] -- --default_meta(::Type{MAR}) = error("MvAutoregressive node requires meta flag explicitly specified") -- --@average_energy MAR (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Wishart, meta::MARMeta) = begin -- ma, Va = mean_cov(q_a) -- myx, Vyx = mean_cov(q_y_x) -- mΛ = mean(q_Λ) -- -- order, ds = getorder(meta), getdimensionality(meta) -- F = Multivariate -- dim = order * ds -- n = div(ndims(q_y_x), 2) -- -- ma, Va = mean_cov(q_a) -- mA = mar_companion_matrix(order, ds, ma)[1:ds, 1:dim] -- -- mx, Vx = ar_slice(F, myx, (dim + 1):(2dim)), ar_slice(F, Vyx, (dim + 1):(2dim), (dim + 1):(2dim)) -- my1, Vy1 = myx[1:ds], Vyx[1:ds, 1:ds] -- Vy1x = ar_slice(F, Vyx, 1:ds, (dim + 1):(2dim)) -- -- # @show Vyx -- # @show Vy1x -- -- # this should be inside MARMeta -- es = [uvector(ds, i) for i in 1:ds] -- Fs = [mask_mar(order, ds, i) for i in 1:ds] -- -- g₁ = my1' * mΛ * my1 + tr(Vy1 * mΛ) -- g₂ = mx' * mA' * mΛ * my1 + tr(Vy1x * mA' * mΛ) -- g₃ = g₂ -- G = sum(sum(es[i]' * mΛ * es[j] * Fs[i] * (ma * ma' + Va) * Fs[j]' for i in 1:ds) for j in 1:ds) -- g₄ = mx' * G * mx + tr(Vx * G) -- AE = n / 2 * log2π - 0.5 * mean(logdet, q_Λ) + 0.5 * (g₁ - g₂ - g₃ + g₄) -- -- if order > 1 -- AE += entropy(q_y_x) -- idc = LazyArrays.Vcat(1:ds, (dim + 1):(2dim)) -- myx_n = view(myx, idc) -- Vyx_n = view(Vyx, idc, idc) -- q_y_x = MvNormalMeanCovariance(myx_n, Vyx_n) -- AE -= entropy(q_y_x) -- end -- -- return AE --end -- --@average_energy MAR ( -- q_y::MultivariateNormalDistributionsFamily, q_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Wishart, meta::MARMeta --) = begin -- ma, Va = mean_cov(q_a) -- my, Vy = mean_cov(q_y) -- mx, Vx = mean_cov(q_y) -- mΛ = mean(q_Λ) -- -- order, ds = getorder(meta), getdimensionality(meta) -- F = Multivariate -- dim = order * ds -- n = dim -- -- ma, Va = mean_cov(q_a) -- mA = mar_companion_matrix(order, ds, ma)[1:ds, 1:dim] -- -- my1, Vy1 = my[1:ds], Vy[1:ds, 1:ds] -- -- # this should be inside MARMeta -- es = [uvector(ds, i) for i in 1:ds] -- Fs = [mask_mar(order, ds, i) for i in 1:ds] -- -- g₁ = my1' * mΛ * my1 + tr(Vy1 * mΛ) -- g₂ = -mx' * mA' * mΛ * my1 -- g₃ = -g₂ -- G = sum(sum(es[i]' * mΛ * es[j] * Fs[i] * (ma * ma' + Va) * Fs[j]' for i in 1:ds) for j in 1:ds) -- g₄ = mx' * G * mx + tr(Vx * G) -- AE = n / 2 * log2π - 0.5 * mean(logdet, q_Λ) + 0.5 * (g₁ + g₂ + g₃ + g₄) -- -- if order > 1 -- AE += entropy(q_y) -- q_y = MvNormalMeanCovariance(my1, Vy1) -- AE -= entropy(q_y) -- end -- -- return AE --end -- --# Helpers for AR rules --function mask_mar(order, dimension, index) -- F = zeros(dimension * order, dimension * dimension * order) -- rows = repeat([dimension], order) -- cols = repeat([dimension], dimension * order) -- FB = BlockArrays.BlockArray(F, rows, cols) -- for k in 1:order -- for j in 1:(dimension * order) -- if j == index + (k - 1) * dimension -- view(FB, BlockArrays.Block(k, j)) .= diageye(dimension) -- end -- end -- end -- return Matrix(FB) --end -- --function mar_transition(order, Λ) -- dim = size(Λ, 1) -- W = 1.0 * diageye(dim * order) -- W[1:dim, 1:dim] = Λ -- return W --end -- --function mar_shift(order, ds) -- dim = order * ds -- S = diageye(dim) -- for i in dim:-1:(ds + 1) -- S[i, :] = S[i - ds, :] -- end -- S[1:ds, :] = zeros(ds, dim) -- return S --end -- --function 
uvector(dim, pos = 1) -- u = zeros(dim) -- u[pos] = 1 -- return dim == 1 ? u[pos] : u --end -- --function mar_companion_matrix(order, ds, a) -- dim = order * ds -- S = mar_shift(order, ds) -- es = [uvector(dim, i) for i in 1:ds] -- Fs = [mask_mar(order, ds, i) for i in 1:ds] -- L = S .+ sum(es[i] * a' * Fs[i]' for i in 1:ds) -- return L --end -diff --git a/src/rule.jl b/src/rule.jl -index 4cef52ef..4447e166 100644 ---- a/src/rule.jl -+++ b/src/rule.jl -@@ -202,21 +202,48 @@ function call_rule_macro_parse_fn_args(inputs; specname, prefix, proxy) - return names_arg, values_arg - end - -+# This trait indicates that a node reference is required for a proper rule execution -+# Most of the message passing update rules do not require a node reference -+# An example of a rule that requires a node is the `delta`, that needs the node function -+struct CallRuleNodeRequired end -+ -+# This trait indicates that a node reference is not required for a proper rule execution -+# This is used by default -+struct CallRuleNodeNotRequired end -+ -+""" -+ call_rule_is_node_required(fformtype) -+ -+Returns either `CallRuleNodeRequired()` or `CallRuleNodeNotRequired()` depending on if a specific -+`fformtype` requires an access to the corresponding node in order to compute a message update rule. -+Returns `CallRuleNodeNotRequired()` for all known functional forms by default and `CallRuleNodeRequired()` for all unknown functional forms. -+""" -+call_rule_is_node_required(fformtype) = call_rule_is_node_required(as_node_functional_form(fformtype), fformtype) -+ -+call_rule_is_node_required(::ValidNodeFunctionalForm, fformtype) = CallRuleNodeNotRequired() -+call_rule_is_node_required(::UndefinedNodeFunctionalForm, fformtype) = CallRuleNodeRequired() -+ -+# Returns the `node` if it is required for a rule, otherwise returns `nothing` -+node_if_required(fformtype, node) = node_if_required(call_rule_is_node_required(fformtype), node) -+ -+node_if_required(::CallRuleNodeRequired, node) = node -+node_if_required(::CallRuleNodeNotRequired, node) = nothing -+ - """ - call_rule_create_node(::Type{ NodeType }, fformtype) - --Creates a node object that will be used inside `@call_rule` macro. The node object always creates with the default options for factorisation. -+Creates a node object that will be used inside `@call_rule` macro. - """ - function call_rule_make_node(fformtype, nodetype, meta) -- return call_rule_make_node(ReactiveMP.as_node_functional_form(nodetype), fformtype, nodetype, meta) -+ return call_rule_make_node(call_rule_is_node_required(nodetype), fformtype, nodetype, meta) - end - --function call_rule_make_node(::UndefinedNodeFunctionalForm, fformtype, nodetype, meta) -- return error("Cannot create a node of type `$nodetype` for the call rule routine.") -+function call_rule_make_node(::CallRuleNodeRequired, fformtype, nodetype, meta) -+ return error("Missing implementation for the `call_rule_make_node`. 
Cannot create a node of type `$nodetype` for the call rule routine.") - end - --function call_rule_make_node(::ValidNodeFunctionalForm, fformtype, nodetype, meta) -- return make_node(nodetype, FactorNodeCreationOptions(nothing, meta, nothing)) -+function call_rule_make_node(::CallRuleNodeNotRequired, fformtype, nodetype, meta) -+ return nothing - end - - call_rule_macro_construct_on_arg(on_type, on_index::Nothing) = MacroHelpers.bottom_type(on_type) -diff --git a/src/rules/bernoulli/marginals.jl b/src/rules/bernoulli/marginals.jl -index 78146875..40a68b2b 100644 ---- a/src/rules/bernoulli/marginals.jl -+++ b/src/rules/bernoulli/marginals.jl -@@ -5,3 +5,7 @@ export marginalrule - p = prod(ProdAnalytical(), Beta(one(r) + r, 2one(r) - r), m_p) - return (out = m_out, p = p) - end -+ -+@marginalrule Bernoulli(:out_p) (m_out::Bernoulli, m_p::PointMass) = begin -+ return (out = prod(ProdAnalytical(), Bernoulli(mean(m_p)), m_out), p = m_p) -+end -diff --git a/src/rules/categorical/marginals.jl b/src/rules/categorical/marginals.jl -index 4befdbb1..88c8890f 100644 ---- a/src/rules/categorical/marginals.jl -+++ b/src/rules/categorical/marginals.jl -@@ -2,3 +2,8 @@ - @marginalrule Categorical(:out_p) (m_out::Categorical, m_p::PointMass) = begin - return (out = prod(ProdAnalytical(), Categorical(mean(m_p)), m_out), p = m_p) - end -+ -+@marginalrule Categorical(:out_p) (m_out::PointMass, m_p::Dirichlet) = begin -+ p = prod(ProdAnalytical(), Dirichlet(probvec(m_out) .+ one(eltype(probvec(m_out)))), m_p) -+ return (out = m_out, p = p) -+end -diff --git a/src/rules/kernel_gcv/marginals.jl b/src/rules/kernel_gcv/marginals.jl -deleted file mode 100644 -index 9745498f..00000000 ---- a/src/rules/kernel_gcv/marginals.jl -+++ /dev/null -@@ -1,33 +0,0 @@ --export marginalrule -- --@marginalrule KernelGCV(:y_x) (m_y::MvNormalMeanCovariance, m_x::MvNormalMeanCovariance, q_z::MvNormalMeanCovariance, meta::KernelGCVMetadata) = begin -- kernelfunction = get_kernelfn(meta) -- Λ = approximate_kernel_expectation(get_approximation(meta), (z) -> cholinv(kernelfunction(z)), q_z) -- -- Λy = invcov(m_y) -- Λx = invcov(m_x) -- -- wy = Λy * mean(m_y) -- wx = Λx * mean(m_x) -- -- C = cholinv([Λ+Λy -Λ; -Λ Λ+Λx]) -- m = C * [wy; wx] -- -- return MvNormalMeanCovariance(m, C) --end -- --@marginalrule KernelGCV(:y_x) (m_y::MvNormalMeanPrecision, m_x::MvNormalMeanPrecision, q_z::MvNormalMeanPrecision, meta::KernelGCVMetadata) = begin -- kernelfunction = get_kernelfn(meta) -- C = approximate_kernel_expectation(get_approximation(meta), (z) -> cholinv(kernelfunction(z)), q_z) -- -- Cy = invcov(m_y) -- Cx = invcov(m_x) -- -- wy = Cy * mean(m_y) -- wx = Cx * mean(m_x) -- -- Λ = [C+Cy -C; -C C+Cx] -- μ = cholinv(Λ) * [wy; wx] -- -- return MvNormalMeanPrecision(μ, Λ) --end -diff --git a/src/rules/kernel_gcv/x.jl b/src/rules/kernel_gcv/x.jl -deleted file mode 100644 -index 29776006..00000000 ---- a/src/rules/kernel_gcv/x.jl -+++ /dev/null -@@ -1,13 +0,0 @@ --export rule -- --@rule KernelGCV(:x, Marginalisation) (m_y::MvNormalMeanCovariance, q_z::MvNormalMeanCovariance, meta::KernelGCVMetadata) = begin -- kernelfunction = get_kernelfn(meta) -- Λ_out = approximate_kernel_expectation(get_approximation(meta), (s) -> cholinv(kernelfunction(s)), q_z) -- return MvNormalMeanCovariance(mean(m_y), cov(m_y) + cholinv(Λ_out)) --end -- --@rule KernelGCV(:x, Marginalisation) (m_y::MvNormalMeanPrecision, q_z::MvNormalMeanPrecision, meta::KernelGCVMetadata) = begin -- kernelfunction = get_kernelfn(meta) -- Λ_out = 
approximate_kernel_expectation(get_approximation(meta), (s) -> cholinv(kernelfunction(s)), q_z) -- return MvNormalMeanPrecision(mean(m_y), cholinv(cov(m_y) + cholinv(Λ_out))) --end -diff --git a/src/rules/kernel_gcv/y.jl b/src/rules/kernel_gcv/y.jl -deleted file mode 100644 -index a6a4b82f..00000000 ---- a/src/rules/kernel_gcv/y.jl -+++ /dev/null -@@ -1,13 +0,0 @@ --export rule -- --@rule KernelGCV(:y, Marginalisation) (m_x::MvNormalMeanCovariance, q_z::MvNormalMeanCovariance, meta::KernelGCVMetadata) = begin -- kernelfunction = get_kernelfn(meta) -- Λ_out = approximate_kernel_expectation(get_approximation(meta), (s) -> cholinv(kernelfunction(s)), q_z) -- return MvNormalMeanCovariance(mean(m_x), cov(m_x) + cholinv(Λ_out)) --end -- --@rule KernelGCV(:y, Marginalisation) (m_x::MvNormalMeanPrecision, q_z::MvNormalMeanPrecision, meta::KernelGCVMetadata) = begin -- kernelfunction = get_kernelfn(meta) -- Λ_out = approximate_kernel_expectation(get_approximation(meta), (s) -> inv(kernelfunction(s)), q_z) -- return MvNormalMeanPrecision(mean(m_x), cholinv(cov(m_x) + cholinv(Λ_out))) --end -diff --git a/src/rules/kernel_gcv/z.jl b/src/rules/kernel_gcv/z.jl -deleted file mode 100644 -index 5ebb4978..00000000 ---- a/src/rules/kernel_gcv/z.jl -+++ /dev/null -@@ -1,53 +0,0 @@ --export rule -- --@rule KernelGCV(:z, Marginalisation) (q_y_x::MvNormalMeanCovariance, meta::KernelGCVMetadata) = begin -- dims = Int64(ndims(q_y_x) / 2) -- -- m_yx = mean(q_y_x) -- cov_yx = cov(q_y_x) -- -- cov11 = @view cov_yx[1:dims, 1:dims] -- cov12 = @view cov_yx[1:dims, (dims + 1):end] -- cov21 = @view cov_yx[(dims + 1):end, 1:dims] -- cov22 = @view cov_yx[(dims + 1):end, (dims + 1):end] -- -- m1 = @view m_yx[1:dims] -- m2 = @view m_yx[(dims + 1):end] -- -- psi = cov11 + cov22 - cov12 - cov21 + (m1 - m2) * (m1 - m2)' -- -- kernelfunction = get_kernelfn(meta) -- -- logpdf = (z) -> begin -- gz = kernelfunction(z) -- -0.5 * (logdet(gz) + tr(cholinv(gz) * psi)) -- end -- -- return FnWithApproximation(logpdf, get_approximation(meta)) --end -- --@rule KernelGCV(:z, Marginalisation) (q_y_x::MvNormalMeanPrecision, meta::KernelGCVMetadata) = begin -- dims = Int64(ndims(q_y_x) / 2) -- -- m_yx = mean(q_y_x) -- cov_yx = cov(q_y_x) -- -- cov11 = @view cov_yx[1:dims, 1:dims] -- cov12 = @view cov_yx[1:dims, (dims + 1):end] -- cov21 = @view cov_yx[(dims + 1):end, 1:dims] -- cov22 = @view cov_yx[(dims + 1):end, (dims + 1):end] -- -- m1 = @view m_yx[1:dims] -- m2 = @view m_yx[(dims + 1):end] -- -- psi = cov11 + cov22 - cov12 - cov21 + (m1 - m2) * (m1 - m2)' -- -- kernelfunction = get_kernelfn(meta) -- -- logpdf = (z) -> begin -- gz = kernelfunction(z) -- -0.5 * (logdet(gz) + tr(cholinv(gz) * psi)) -- end -- -- return FnWithApproximation(logpdf, get_approximation(meta)) --end -diff --git a/src/rules/mv_autoregressive/a.jl b/src/rules/mv_autoregressive/a.jl -deleted file mode 100644 -index 0777efb7..00000000 ---- a/src/rules/mv_autoregressive/a.jl -+++ /dev/null -@@ -1,50 +0,0 @@ -- --@rule MAR(:a, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin -- order, ds = getorder(meta), getdimensionality(meta) -- F = Multivariate -- -- dim = order * ds -- -- m, V = mean_cov(q_y_x) -- -- my, Vy = ar_slice(F, m, 1:dim), ar_slice(F, V, 1:dim, 1:dim) -- mx, Vx = ar_slice(F, m, (dim + 1):(2dim)), ar_slice(F, V, (dim + 1):(2dim), (dim + 1):(2dim)) -- Vyx = ar_slice(F, V, 1:dim, (dim + 1):(2dim)) -- -- mΛ = mean(q_Λ) -- mW = mar_transition(order, mΛ) -- -- # this should be inside MARMeta -- es = [uvector(dim, i) 
for i in 1:ds] -- Fs = [mask_mar(order, ds, i) for i in 1:ds] -- S = mar_shift(order, ds) -- -- # NOTE: prove that sum(Fs[i]'*((mx*mx'+Vx')*S')*mW*es[i] for i in 1:ds) == 0.0 -- D = sum(sum(es[i]' * mW * es[j] * Fs[i]' * (mx * mx' + Vx) * Fs[j] for i in 1:ds) for j in 1:ds) -- z = sum(Fs[i]' * (mx * my' + Vyx') * mW * es[i] for i in 1:ds) -- -- return MvNormalWeightedMeanPrecision(z, D) --end -- --@rule MAR(:a, Marginalisation) (q_y::MultivariateNormalDistributionsFamily, q_x::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin -- order, ds = getorder(meta), getdimensionality(meta) -- F = Multivariate -- -- dim = order * ds -- -- my, Vy = mean_cov(q_y) -- mx, Vx = mean_cov(q_x) -- mΛ = mean(q_Λ) -- -- mW = mar_transition(order, mΛ) -- S = mar_shift(order, ds) -- -- # this should be inside MARMeta -- es = [uvector(dim, i) for i in 1:ds] -- Fs = [mask_mar(order, ds, i) for i in 1:ds] -- -- D = sum(sum(es[j]' * mW * es[i] * Fs[i]' * (mx * mx' + Vx) * Fs[j] for i in 1:ds) for j in 1:ds) -- z = sum(Fs[i]' * ((mx * mx' + Vx') * S' + mx * my') * mW * es[i] for i in 1:ds) -- -- return MvNormalWeightedMeanPrecision(z, D) --end -diff --git a/src/rules/mv_autoregressive/lambda.jl b/src/rules/mv_autoregressive/lambda.jl -deleted file mode 100644 -index 29f88cba..00000000 ---- a/src/rules/mv_autoregressive/lambda.jl -+++ /dev/null -@@ -1,53 +0,0 @@ --@rule MAR(:Λ, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, meta::MARMeta) = begin -- order, ds = getorder(meta), getdimensionality(meta) -- F = Multivariate -- dim = order * ds -- -- ma, Va = mean_cov(q_a) -- -- mA = mar_companion_matrix(order, ds, ma) -- -- m, V = mean_cov(q_y_x) -- my, Vy = ar_slice(F, m, 1:dim), ar_slice(F, V, 1:dim, 1:dim) -- mx, Vx = ar_slice(F, m, (dim + 1):(2dim)), ar_slice(F, V, (dim + 1):(2dim), (dim + 1):(2dim)) -- Vyx = ar_slice(F, V, 1:dim, (dim + 1):(2dim)) -- -- es = [uvector(dim, i) for i in 1:ds] -- Fs = [mask_mar(order, ds, i) for i in 1:ds] -- S = mar_shift(order, ds) -- G₁ = (my * my' + Vy)[1:ds, 1:ds] -- G₂ = ((my * mx' + Vyx) * mA')[1:ds, 1:ds] -- G₃ = transpose(G₂) -- Ex_xx = mx * mx' + Vx -- G₅ = sum(sum(es[i] * ma' * Fs[i]'Ex_xx * Fs[j] * ma * es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] -- G₆ = sum(sum(es[i] * tr(Fs[i]' * Ex_xx * Fs[j] * Va) * es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] -- Δ = G₁ - G₂ - G₃ + G₅ + G₆ -- -- return WishartMessage(ds + 2, Δ) --end -- --@rule MAR(:Λ, Marginalisation) (q_y::MultivariateNormalDistributionsFamily, q_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, meta::MARMeta) = -- begin -- order, ds = getorder(meta), getdimensionality(meta) -- F = Multivariate -- dim = order * ds -- -- my, Vy = mean_cov(q_y) -- mx, Vx = mean_cov(q_x) -- ma, Va = mean_cov(q_a) -- -- mA = mar_companion_matrix(order, ds, ma) -- -- es = [uvector(dim, i) for i in 1:ds] -- Fs = [mask_mar(order, ds, i) for i in 1:ds] -- S = mar_shift(order, ds) -- G₁ = (my * my' + Vy)[1:ds, 1:ds] -- G₂ = (my * mx' * mA')[1:ds, 1:ds] -- G₃ = transpose(G₂) -- Ex_xx = mx * mx' + Vx -- G₅ = sum(sum(es[i] * ma' * Fs[j]'Ex_xx * Fs[i] * ma * es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] -- G₆ = sum(sum(es[i] * tr(Va * Fs[i]' * Ex_xx * Fs[j]) * es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] -- Δ = G₁ - G₂ - G₃ + G₅ + G₆ -- -- return WishartMessage(ds + 2, Δ) -- end -diff --git a/src/rules/mv_autoregressive/marginals.jl b/src/rules/mv_autoregressive/marginals.jl -deleted file mode 100644 -index 
92a71a65..00000000 ---- a/src/rules/mv_autoregressive/marginals.jl -+++ /dev/null -@@ -1,46 +0,0 @@ -- --@marginalrule MAR(:y_x) ( -- m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta --) = begin -- return ar_y_x_marginal(m_y, m_x, q_a, q_Λ, meta) --end -- --function ar_y_x_marginal( -- m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta --) -- order, ds = getorder(meta), getdimensionality(meta) -- F = Multivariate -- dim = order * ds -- -- ma, Va = mean_cov(q_a) -- mΛ = mean(q_Λ) -- -- mA = mar_companion_matrix(order, ds, ma) -- mW = mar_transition(getorder(meta), mΛ) -- -- b_my, b_Vy = mean_cov(m_y) -- f_mx, f_Vx = mean_cov(m_x) -- -- inv_b_Vy = cholinv(b_Vy) -- inv_f_Vx = cholinv(f_Vx) -- -- # this should be inside MARMeta -- es = [uvector(dim, i) for i in 1:ds] -- Fs = [mask_mar(order, ds, i) for i in 1:ds] -- -- Ξ = inv_f_Vx + sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds) -- -- W_11 = inv_b_Vy + mW -- -- # negate_inplace!(mW * mA) -- W_12 = -(mW * mA) -- -- W_21 = -(mA' * mW) -- -- W_22 = Ξ + mA' * mW * mA -- -- W = [W_11 W_12; W_21 W_22] -- ξ = [inv_b_Vy * b_my; inv_f_Vx * f_mx] -- -- return MvNormalWeightedMeanPrecision(ξ, W) --end -diff --git a/src/rules/mv_autoregressive/x.jl b/src/rules/mv_autoregressive/x.jl -deleted file mode 100644 -index e191589a..00000000 ---- a/src/rules/mv_autoregressive/x.jl -+++ /dev/null -@@ -1,50 +0,0 @@ -- --@rule MAR(:x, Marginalisation) (m_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin -- ma, Va = mean_cov(q_a) -- my, Vy = mean_cov(m_y) -- -- mΛ = mean(q_Λ) -- -- order, ds = getorder(meta), getdimensionality(meta) -- dim = order * ds -- -- mA = mar_companion_matrix(order, ds, ma) -- mW = mar_transition(getorder(meta), mΛ) -- # this should be inside MARMeta -- es = [uvector(dim, i) for i in 1:ds] -- Fs = [mask_mar(order, ds, i) for i in 1:ds] -- -- Λ = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds) -- -- Σ₁ = Hermitian(pinv(mA) * (Vy) * pinv(mA') + pinv(mA' * mW * mA)) -- -- Ξ = (pinv(Σ₁) + Λ) -- z = pinv(Σ₁) * pinv(mA) * my -- -- return MvNormalWeightedMeanPrecision(z, Ξ) --end -- --@rule MAR(:x, Marginalisation) (q_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin -- ma, Va = mean_cov(q_a) -- my, Vy = mean_cov(q_y) -- -- mΛ = mean(q_Λ) -- -- order, ds = getorder(meta), getdimensionality(meta) -- dim = order * ds -- -- mA = mar_companion_matrix(order, ds, ma) -- mW = mar_transition(getorder(meta), mΛ) -- -- # this should be inside MARMeta -- es = [uvector(dim, i) for i in 1:ds] -- Fs = [mask_mar(order, ds, i) for i in 1:ds] -- -- Λ = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds) -- Λ₀ = Hermitian(mA' * mW * mA) -- -- Ξ = Λ₀ + Λ -- z = Λ₀ * pinv(mA) * my -- -- return MvNormalWeightedMeanPrecision(z, Ξ) --end -diff --git a/src/rules/mv_autoregressive/y.jl b/src/rules/mv_autoregressive/y.jl -deleted file mode 100644 -index b99ace9b..00000000 ---- a/src/rules/mv_autoregressive/y.jl -+++ /dev/null -@@ -1,34 +0,0 @@ --@rule MAR(:y, Marginalisation) (m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin -- ma, Va = mean_cov(q_a) -- mx, Wx = 
mean_invcov(m_x) -- -- mΛ = mean(q_Λ) -- -- order, ds = getorder(meta), getdimensionality(meta) -- -- mA = mar_companion_matrix(order, ds, ma) -- mW = mar_transition(getorder(meta), mΛ) -- dim = order * ds -- # this should be inside MARMeta -- es = [uvector(dim, i) for i in 1:ds] -- Fs = [mask_mar(order, ds, i) for i in 1:ds] -- -- Λ = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds) -- -- Ξ = Λ + Wx -- z = Wx * mx -- -- Vy = mA * inv(Ξ) * mA' + inv(mW) -- my = mA * inv(Ξ) * z -- -- return MvNormalMeanCovariance(my, Vy) --end -- --@rule MAR(:y, Marginalisation) (q_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin -- order, ds = getorder(meta), getdimensionality(meta) -- -- mA = mar_companion_matrix(order, ds, mean(q_a)) -- mW = mar_transition(getorder(meta), mean(q_Λ)) -- -- return MvNormalMeanPrecision(mA * mean(q_x), mW) --end -diff --git a/src/rules/prototypes.jl b/src/rules/prototypes.jl -index 81395dd7..5c56ab65 100644 ---- a/src/rules/prototypes.jl -+++ b/src/rules/prototypes.jl -@@ -52,11 +52,6 @@ include("gcv/w.jl") - include("gcv/marginals.jl") - include("gcv/gaussian_extension.jl") - --include("kernel_gcv/x.jl") --include("kernel_gcv/y.jl") --include("kernel_gcv/z.jl") --include("kernel_gcv/marginals.jl") -- - include("mv_normal_mean_covariance/out.jl") - include("mv_normal_mean_covariance/mean.jl") - include("mv_normal_mean_covariance/covariance.jl") -@@ -116,12 +111,6 @@ include("autoregressive/theta.jl") - include("autoregressive/gamma.jl") - include("autoregressive/marginals.jl") - --include("mv_autoregressive/y.jl") --include("mv_autoregressive/x.jl") --include("mv_autoregressive/a.jl") --include("mv_autoregressive/lambda.jl") --include("mv_autoregressive/marginals.jl") -- - include("probit/marginals.jl") - include("probit/in.jl") - include("probit/out.jl") -diff --git a/src/variables/constant.jl b/src/variables/constant.jl -index 6cb40aed..b808acd6 100644 ---- a/src/variables/constant.jl -+++ b/src/variables/constant.jl -@@ -50,8 +50,8 @@ function constvar end - - constvar(name::Symbol, constval, collection_type::AbstractVariableCollectionType = VariableIndividual()) = ConstVariable(name, collection_type, constval, of(Message(constval, true, false, nothing)), 0) - constvar(name::Symbol, constval::Real, collection_type::AbstractVariableCollectionType = VariableIndividual()) = constvar(name, PointMass(constval), collection_type) --constvar(name::Symbol, constval::AbstractVector, collection_type::AbstractVariableCollectionType = VariableIndividual()) = constvar(name, PointMass(constval), collection_type) --constvar(name::Symbol, constval::AbstractMatrix, collection_type::AbstractVariableCollectionType = VariableIndividual()) = constvar(name, PointMass(constval), collection_type) -+constvar(name::Symbol, constval::AbstractArray, collection_type::AbstractVariableCollectionType = VariableIndividual()) = constvar(name, PointMass(constval), collection_type) -+constvar(name::Symbol, constval::UniformScaling, collection_type::AbstractVariableCollectionType = VariableIndividual()) = constvar(name, PointMass(constval), collection_type) - - function constvar(name::Symbol, fn::Function, length::Int) - return map(i -> constvar(name, fn(i), VariableVector(i)), 1:length) -diff --git a/src/variables/data.jl b/src/variables/data.jl -index beb05410..6600cda8 100644 ---- a/src/variables/data.jl -+++ b/src/variables/data.jl -@@ -9,15 +9,19 @@ mutable struct DataVariable{D, S} <: AbstractVariable - 
input_messages :: Vector{MessageObservable{AbstractMessage}} - messageout :: S - nconnected :: Int -+ isproxy :: Bool -+ isused :: Bool - end - - Base.show(io::IO, datavar::DataVariable) = print(io, "DataVariable(", indexed_name(datavar), ")") - - struct DataVariableCreationOptions{S} -- subject::S -+ subject :: S -+ isproxy :: Bool -+ isused :: Bool - end - --Base.similar(options::DataVariableCreationOptions) = DataVariableCreationOptions(similar(options.subject)) -+Base.similar(options::DataVariableCreationOptions) = DataVariableCreationOptions(similar(options.subject), options.isproxy, options.isused) - - DataVariableCreationOptions(::Type{D}) where {D} = DataVariableCreationOptions(D, nothing) - DataVariableCreationOptions(::Type{D}, subject) where {D} = DataVariableCreationOptions(D, subject, Val(false)) -@@ -26,7 +30,7 @@ DataVariableCreationOptions(::Type{D}, subject::Nothing, allow_missing::Val{true - DataVariableCreationOptions(::Type{D}, subject::Nothing, allow_missing::Val{false}) where {D} = DataVariableCreationOptions(D, RecentSubject(Union{Message{D}}), Val(false)) - - DataVariableCreationOptions(::Type{D}, subject::S, ::Val{true}) where {D, S} = error("Error in datavar options. Custom `subject` was specified and `allow_missing` was set to true, which is disallowed. Provide a custom subject that accept missing values by itself and do no use `allow_missing` option.") --DataVariableCreationOptions(::Type{D}, subject::S, ::Val{false}) where {D, S} = DataVariableCreationOptions{S}(subject) -+DataVariableCreationOptions(::Type{D}, subject::S, ::Val{false}) where {D, S} = DataVariableCreationOptions{S}(subject, false, false) - - """ - datavar(::Type, [ dims... ]) -@@ -72,7 +76,7 @@ datavar(name::Symbol, ::Type{D}, dims::Tuple) where {D} - datavar(name::Symbol, ::Type{D}, dims::Vararg{Int}) where {D} = datavar(DataVariableCreationOptions(D), name, D, dims) - - datavar(options::DataVariableCreationOptions{S}, name::Symbol, ::Type{D}, collection_type::AbstractVariableCollectionType = VariableIndividual()) where {S, D} = -- DataVariable{D, S}(name, collection_type, MarginalObservable(), Vector{MessageObservable{AbstractMessage}}(), options.subject, 0) -+ DataVariable{D, S}(name, collection_type, MarginalObservable(), Vector{MessageObservable{AbstractMessage}}(), options.subject, 0, options.isproxy, options.isused) - - function datavar(options::DataVariableCreationOptions, name::Symbol, ::Type{D}, length::Int) where {D} - return map(i -> datavar(similar(options), name, D, VariableVector(i)), 1:length) -@@ -93,12 +97,13 @@ Base.eltype(::DataVariable{D}) where {D} = D - - degree(datavar::DataVariable) = nconnected(datavar) - name(datavar::DataVariable) = datavar.name --proxy_variables(datavar::DataVariable) = nothing -+proxy_variables(datavar::DataVariable) = nothing # not related to isproxy - collection_type(datavar::DataVariable) = datavar.collection_type - isconnected(datavar::DataVariable) = datavar.nconnected !== 0 - nconnected(datavar::DataVariable) = datavar.nconnected - --isproxy(::DataVariable) = false -+isproxy(datavar::DataVariable) = datavar.isproxy -+isused(datavar::DataVariable) = datavar.isused - - israndom(::DataVariable) = false - israndom(::AbstractArray{<:DataVariable}) = false -@@ -117,7 +122,7 @@ function Base.getindex(datavar::DataVariable, i...) - error("Variable $(indexed_name(datavar)) has been indexed with `[$(join(i, ','))]`. 
Direct indexing of `data` variables is not allowed.") - end - --getlastindex(::DataVariable) = 1 -+getlastindex(datavar::DataVariable) = degree(datavar) + 1 - - messageout(datavar::DataVariable, ::Int) = datavar.messageout - messagein(datavar::DataVariable, ::Int) = error("It is not possible to get a reference for inbound message for datavar") -@@ -168,16 +173,18 @@ _getmarginal(datavar::DataVariable) = datavar.messageout |> map(Mar - _setmarginal!(datavar::DataVariable, observable) = error("It is not possible to set a marginal stream for `DataVariable`") - _makemarginal(datavar::DataVariable) = error("It is not possible to make marginal stream for `DataVariable`") - --# Extension for _getmarginal --function Rocket.getrecent(proxy::ProxyObservable{<:Marginal, S, M}) where {S <: Rocket.RecentSubjectInstance, D, M <: Rocket.MapProxy{D, typeof(as_marginal)}} -- return as_marginal(Rocket.getrecent(proxy.proxied_source)) --end -- - setanonymous!(::DataVariable, ::Bool) = nothing - --function setmessagein!(datavar::DataVariable, ::Int, messagein) -- datavar.nconnected += 1 -- push!(datavar.input_messages, messagein) -+function setmessagein!(datavar::DataVariable, index::Int, messagein) -+ if index === (degree(datavar) + 1) -+ push!(datavar.input_messages, messagein) -+ datavar.nconnected += 1 -+ datavar.isused = true -+ else -+ error( -+ "Inconsistent state in setmessagein! function for data variable $(datavar). `index` should be equal to `degree(datavar) + 1 = $(degree(datavar) + 1)`, but $(index) was given instead" -+ ) -+ end - return nothing - end - -diff --git a/src/variables/variable.jl b/src/variables/variable.jl -index 5d2be480..bf990a7b 100644 ---- a/src/variables/variable.jl -+++ b/src/variables/variable.jl -@@ -147,7 +147,7 @@ track of `proxy_variables`. During the first call of `get_factorisation_referenc - 2. if yes we pass it futher to the `unchecked` version of the function - 2.1 `unchecked` version return immediatelly if there is only one proxy var (see bullet 1) - 2.2 in case of multiple proxy vars we filter only `RandomVariable` and call `checked` version of the function --3. `checked` version of the function return immediatelly if there is only one proxy random variable left, if there are multuple proxy random vars we throw an error as this case is ambigous for factorisation constrains specification -+3. `checked` version of the function returns immediately if there is only one proxy random variable left; if there are multiple proxy random vars we throw an error as this case is ambiguous for factorisation constraints specification - - This function is a part of private API and should not be used explicitly. 
- """ -diff --git a/test/approximations/test_cvi.jl b/test/approximations/test_cvi.jl -index fbc9f3b7..cf6a1df0 100644 ---- a/test/approximations/test_cvi.jl -+++ b/test/approximations/test_cvi.jl -@@ -80,8 +80,8 @@ end - rng = StableRNG(42) - - tests = ( -- (method = CVI(StableRNG(42), 1, 1000, Descent(0.01), ForwardDiffGrad(), 1, Val(true), false), tol = 5e-1), -- (method = CVI(StableRNG(42), 1, 1000, Descent(0.01), ZygoteGrad(), 1, Val(true), false), tol = 5e-1) -+ (method = CVI(StableRNG(42), 1, 1000, Descent(0.01), ForwardDiffGrad(), 10, Val(true), false), tol = 5e-1), -+ (method = CVI(StableRNG(42), 1, 1000, Descent(0.01), ZygoteGrad(), 10, Val(true), false), tol = 5e-1) - ) - - # Check several prods against their analytical solutions -@@ -122,9 +122,16 @@ end - - b1 = Bernoulli(logistic(randn(rng))) - b2 = Bernoulli(logistic(randn(rng))) -- b_analitical = prod(ProdAnalytical(), b1, b2) -+ b_analytical = prod(ProdAnalytical(), b1, b2) - b_cvi = prod(test[:method], b1, b1) -- @test isapprox(mean(b_analitical), mean(b_cvi), atol = test[:tol]) -+ @test isapprox(mean(b_analytical), mean(b_cvi), atol = test[:tol]) -+ -+ beta_1 = Beta(abs(randn(rng)) + 1, abs(randn(rng)) + 1) -+ beta_2 = Beta(abs(randn(rng)) + 1, abs(randn(rng)) + 1) -+ -+ beta_analytical = prod(ProdAnalytical(), beta_1, beta_2) -+ beta_cvi = prod(test[:method], beta_1, beta_2) -+ @test isapprox(mean(beta_analytical), mean(beta_cvi), atol = test[:tol]) - end - end - -diff --git a/test/distributions/test_bernoulli.jl b/test/distributions/test_bernoulli.jl -index 2abc4968..8a469cd6 100644 ---- a/test/distributions/test_bernoulli.jl -+++ b/test/distributions/test_bernoulli.jl -@@ -26,6 +26,18 @@ using ReactiveMP: compute_logscale - @test prod(ProdAnalytical(), Bernoulli(0.78), Bernoulli(0.05)) ≈ Bernoulli(0.1572580645161291) - end - -+ @testset "prod Bernoulli-Categorical" begin -+ @test prod(ProdAnalytical(), Bernoulli(0.5), Categorical([1.0])) == Categorical([1.0, 0.0]) -+ @test prod(ProdAnalytical(), Bernoulli(0.6), Categorical([0.7, 0.3])) == Categorical([0.6086956521739131, 0.391304347826087]) -+ @test prod(ProdAnalytical(), Bernoulli(0.8), Categorical([0.2, 0.4, 0.4])) == Categorical([0.11111111111111108, 0.8888888888888888, 0.0]) -+ end -+ -+ @testset "prod Categorical-Bernoulli" begin -+ @test prod(ProdAnalytical(), Categorical([1.0]), Bernoulli(0.5)) == Categorical([1.0, 0.0]) -+ @test prod(ProdAnalytical(), Categorical([0.7, 0.3]), Bernoulli(0.6)) == Categorical([0.6086956521739131, 0.391304347826087]) -+ @test prod(ProdAnalytical(), Categorical([0.2, 0.4, 0.4]), Bernoulli(0.8)) == Categorical([0.11111111111111108, 0.8888888888888888, 0.0]) -+ end -+ - @testset "probvec" begin - @test probvec(Bernoulli(0.5)) === (0.5, 0.5) - @test probvec(Bernoulli(0.3)) === (0.7, 0.3) -diff --git a/test/distributions/test_beta.jl b/test/distributions/test_beta.jl -index 0510f61c..14b4f5ab 100644 ---- a/test/distributions/test_beta.jl -+++ b/test/distributions/test_beta.jl -@@ -6,6 +6,7 @@ using Distributions - using Random - - import ReactiveMP: mirrorlog -+import SpecialFunctions: loggamma - - @testset "Beta" begin - -@@ -37,6 +38,38 @@ import ReactiveMP: mirrorlog - @test mean(mirrorlog, Beta(0.1, 0.3)) ≈ -0.9411396776150167 - @test mean(mirrorlog, Beta(4.5, 0.3)) ≈ -4.963371962929249 - end -+ -+ @testset "BetaNaturalParameters" begin -+ @testset "Constructor" begin -+ for i in 0:10, j in 0:10 -+ @test convert(Distribution, BetaNaturalParameters(i, j)) == Beta(i + 1, j + 1) -+ -+ @test convert(BetaNaturalParameters, i, j) == 
BetaNaturalParameters(i, j) -+ @test convert(BetaNaturalParameters, [i, j]) == BetaNaturalParameters(i, j) -+ end -+ end -+ -+ @testset "lognormalizer" begin -+ @test lognormalizer(BetaNaturalParameters(0, 0)) ≈ 0 -+ @test lognormalizer(BetaNaturalParameters(1, 1)) ≈ -loggamma(4) -+ end -+ -+ @testset "logpdf" begin -+ for i in 0:10, j in 0:10 -+ @test logpdf(BetaNaturalParameters(i, j), 0.01) ≈ logpdf(Beta(i + 1, j + 1), 0.01) -+ @test logpdf(BetaNaturalParameters(i, j), 0.5) ≈ logpdf(Beta(i + 1, j + 1), 0.5) -+ end -+ end -+ -+ @testset "isproper" begin -+ for i in 0:10 -+ @test isproper(BetaNaturalParameters(i, i)) === true -+ end -+ for i in 1:10 -+ @test isproper(BetaNaturalParameters(-i, -i)) === false -+ end -+ end -+ end - end - - end -diff --git a/test/distributions/test_mv_normal_mean_covariance.jl b/test/distributions/test_mv_normal_mean_covariance.jl -index ef9af84d..55b4d4dd 100644 ---- a/test/distributions/test_mv_normal_mean_covariance.jl -+++ b/test/distributions/test_mv_normal_mean_covariance.jl -@@ -14,6 +14,13 @@ using Distributions - @test MvNormalMeanCovariance([1, 2]) == MvNormalMeanCovariance([1.0, 2.0], [1.0, 1.0]) - @test MvNormalMeanCovariance([1.0f0, 2.0f0]) == MvNormalMeanCovariance([1.0f0, 2.0f0], [1.0f0, 1.0f0]) - -+ # uniformscaling -+ @test MvNormalMeanCovariance([1, 2], I) == MvNormalMeanCovariance([1, 2], Diagonal([1, 1])) -+ @test MvNormalMeanCovariance([1, 2], 6 * I) == MvNormalMeanCovariance([1, 2], Diagonal([6, 6])) -+ @test MvNormalMeanCovariance([1.0, 2.0], I) == MvNormalMeanCovariance([1.0, 2.0], Diagonal([1.0, 1.0])) -+ @test MvNormalMeanCovariance([1.0, 2.0], 6 * I) == MvNormalMeanCovariance([1.0, 2.0], Diagonal([6.0, 6.0])) -+ @test MvNormalMeanCovariance([1, 2], 6.0 * I) == MvNormalMeanCovariance([1.0, 2.0], Diagonal([6.0, 6.0])) -+ - @test eltype(MvNormalMeanCovariance([1.0, 1.0])) === Float64 - @test eltype(MvNormalMeanCovariance([1.0, 1.0], [1.0, 1.0])) === Float64 - @test eltype(MvNormalMeanCovariance([1, 1])) === Float64 -@@ -91,6 +98,14 @@ using Distributions - dist = MvNormalMeanCovariance(μ, Σ) - - @test prod(ProdAnalytical(), dist, dist) ≈ MvNormalWeightedMeanPrecision([2.0, 2.0, 2.0], diagm([2.0, 1.0, 2 / 3])) -+ -+ # diagonal covariance matrix/uniformscaling -+ @test prod(ProdAnalytical(), MvNormalMeanCovariance([-1, -1], [2 0; 0 2]), MvNormalMeanCovariance([1, 1], Diagonal([2, 4]))) ≈ -+ MvNormalWeightedMeanPrecision([0, -1 / 4], [1, 3 / 4]) -+ @test prod(ProdAnalytical(), MvNormalMeanCovariance([-1, -1], [2, 2]), MvNormalMeanCovariance([1, 1], Diagonal([2, 4]))) ≈ -+ MvNormalWeightedMeanPrecision([0, -1 / 4], [1, 3 / 4]) -+ @test prod(ProdAnalytical(), MvNormalMeanCovariance([-1, -1], 2 * I), MvNormalMeanCovariance([1, 1], Diagonal([2, 4]))) ≈ -+ MvNormalWeightedMeanPrecision([0, -1 / 4], [1, 3 / 4]) - end - - @testset "Primitive types conversion" begin -diff --git a/test/distributions/test_mv_normal_mean_precision.jl b/test/distributions/test_mv_normal_mean_precision.jl -index 97afa3b5..4f00fd2e 100644 ---- a/test/distributions/test_mv_normal_mean_precision.jl -+++ b/test/distributions/test_mv_normal_mean_precision.jl -@@ -14,6 +14,13 @@ using Distributions - @test MvNormalMeanPrecision([1, 2]) == MvNormalMeanPrecision([1.0, 2.0], [1.0, 1.0]) - @test MvNormalMeanPrecision([1.0f0, 2.0f0]) == MvNormalMeanPrecision([1.0f0, 2.0f0], [1.0f0, 1.0f0]) - -+ # uniformscaling -+ @test MvNormalMeanPrecision([1, 2], I) == MvNormalMeanPrecision([1, 2], Diagonal([1, 1])) -+ @test MvNormalMeanPrecision([1, 2], 6 * I) == MvNormalMeanPrecision([1, 
2], Diagonal([6, 6])) -+ @test MvNormalMeanPrecision([1.0, 2.0], I) == MvNormalMeanPrecision([1.0, 2.0], Diagonal([1.0, 1.0])) -+ @test MvNormalMeanPrecision([1.0, 2.0], 6 * I) == MvNormalMeanPrecision([1.0, 2.0], Diagonal([6.0, 6.0])) -+ @test MvNormalMeanPrecision([1, 2], 6.0 * I) == MvNormalMeanPrecision([1.0, 2.0], Diagonal([6.0, 6.0])) -+ - @test eltype(MvNormalMeanPrecision([1.0, 1.0])) === Float64 - @test eltype(MvNormalMeanPrecision([1.0, 1.0], [1.0, 1.0])) === Float64 - @test eltype(MvNormalMeanPrecision([1, 1])) === Float64 -@@ -91,6 +98,11 @@ using Distributions - dist = MvNormalMeanPrecision(μ, Λ) - - @test prod(ProdAnalytical(), dist, dist) ≈ MvNormalWeightedMeanPrecision([2.0, 2.0, 2.0], diagm([2.0, 1.0, 2 / 3])) -+ -+ # diagonal covariance matrix/uniformscaling -+ @test prod(ProdAnalytical(), MvNormalMeanPrecision([-1, -1], [2 0; 0 2]), MvNormalMeanPrecision([1, 1], Diagonal([2, 4]))) ≈ MvNormalWeightedMeanPrecision([0, 2], [4, 6]) -+ @test prod(ProdAnalytical(), MvNormalMeanPrecision([-1, -1], [2, 2]), MvNormalMeanPrecision([1, 1], Diagonal([2, 4]))) ≈ MvNormalWeightedMeanPrecision([0, 2], [4, 6]) -+ @test prod(ProdAnalytical(), MvNormalMeanPrecision([-1, -1], 2 * I), MvNormalMeanPrecision([1, 1], Diagonal([2, 4]))) ≈ MvNormalWeightedMeanPrecision([0, 2], [4, 6]) - end - - @testset "Primitive types conversion" begin -diff --git a/test/distributions/test_mv_normal_weighted_mean_precision.jl b/test/distributions/test_mv_normal_weighted_mean_precision.jl -index 681adf6d..ee28cd68 100644 ---- a/test/distributions/test_mv_normal_weighted_mean_precision.jl -+++ b/test/distributions/test_mv_normal_weighted_mean_precision.jl -@@ -14,6 +14,13 @@ using Distributions - @test MvNormalWeightedMeanPrecision([1, 2]) == MvNormalWeightedMeanPrecision([1.0, 2.0], [1.0, 1.0]) - @test MvNormalWeightedMeanPrecision([1.0f0, 2.0f0]) == MvNormalWeightedMeanPrecision([1.0f0, 2.0f0], [1.0f0, 1.0f0]) - -+ # uniformscaling -+ @test MvNormalWeightedMeanPrecision([1, 2], I) == MvNormalWeightedMeanPrecision([1, 2], Diagonal([1, 1])) -+ @test MvNormalWeightedMeanPrecision([1, 2], 6 * I) == MvNormalWeightedMeanPrecision([1, 2], Diagonal([6, 6])) -+ @test MvNormalWeightedMeanPrecision([1.0, 2.0], I) == MvNormalWeightedMeanPrecision([1.0, 2.0], Diagonal([1.0, 1.0])) -+ @test MvNormalWeightedMeanPrecision([1.0, 2.0], 6 * I) == MvNormalWeightedMeanPrecision([1.0, 2.0], Diagonal([6.0, 6.0])) -+ @test MvNormalWeightedMeanPrecision([1, 2], 6.0 * I) == MvNormalWeightedMeanPrecision([1.0, 2.0], Diagonal([6.0, 6.0])) -+ - @test eltype(MvNormalWeightedMeanPrecision([1.0, 1.0])) === Float64 - @test eltype(MvNormalWeightedMeanPrecision([1.0, 1.0], [1.0, 1.0])) === Float64 - @test eltype(MvNormalWeightedMeanPrecision([1, 1])) === Float64 -@@ -91,6 +98,14 @@ using Distributions - dist = MvNormalWeightedMeanPrecision(xi, Λ) - - @test prod(ProdAnalytical(), dist, dist) ≈ MvNormalWeightedMeanPrecision([0.40, 6.00, 8.00], [3.00 -0.20 0.20; -0.20 3.60 0.00; 0.20 0.00 7.00]) -+ -+ # diagonal covariance matrix/uniformscaling -+ @test prod(ProdAnalytical(), MvNormalWeightedMeanPrecision([-1, -1], [2 0; 0 2]), MvNormalWeightedMeanPrecision([1, 1], Diagonal([2, 4]))) ≈ -+ MvNormalWeightedMeanPrecision([0, 0], [4, 6]) -+ @test prod(ProdAnalytical(), MvNormalWeightedMeanPrecision([-1, -1], [2, 2]), MvNormalWeightedMeanPrecision([1, 1], Diagonal([2, 4]))) ≈ -+ MvNormalWeightedMeanPrecision([0, 0], [4, 6]) -+ @test prod(ProdAnalytical(), MvNormalWeightedMeanPrecision([-1, -1], 2 * I), MvNormalWeightedMeanPrecision([1, 1], Diagonal([2, 
4]))) ≈ -+ MvNormalWeightedMeanPrecision([0, 0], [4, 6]) - end - - @testset "Primitive types conversion" begin -diff --git a/test/distributions/test_normal_mean_precision.jl b/test/distributions/test_normal_mean_precision.jl -index 3db706ac..006ef088 100644 ---- a/test/distributions/test_normal_mean_precision.jl -+++ b/test/distributions/test_normal_mean_precision.jl -@@ -3,6 +3,8 @@ module NormalMeanPrecisionTest - using Test - using ReactiveMP - -+using LinearAlgebra: I -+ - @testset "NormalMeanPrecision" begin - @testset "Constructor" begin - @test NormalMeanPrecision <: NormalDistributionsFamily -@@ -20,6 +22,13 @@ using ReactiveMP - @test NormalMeanPrecision(1.0f0, 2) == NormalMeanPrecision{Float32}(1.0f0, 2.0f0) - @test NormalMeanPrecision(1.0f0, 2.0) == NormalMeanPrecision{Float64}(1.0, 2.0) - -+ # uniformscaling -+ @test NormalMeanPrecision(2, I) == NormalMeanPrecision(2, 1) -+ @test NormalMeanPrecision(2, 6 * I) == NormalMeanPrecision(2, 6) -+ @test NormalMeanPrecision(2.0, I) == NormalMeanPrecision(2.0, 1.0) -+ @test NormalMeanPrecision(2.0, 6 * I) == NormalMeanPrecision(2.0, 6.0) -+ @test NormalMeanPrecision(2, 6.0 * I) == NormalMeanPrecision(2.0, 6.0) -+ - @test eltype(NormalMeanPrecision()) === Float64 - @test eltype(NormalMeanPrecision(0.0)) === Float64 - @test eltype(NormalMeanPrecision(0.0, 1.0)) === Float64 -diff --git a/test/distributions/test_normal_mean_variance.jl b/test/distributions/test_normal_mean_variance.jl -index 27b27389..a17ff8a4 100644 ---- a/test/distributions/test_normal_mean_variance.jl -+++ b/test/distributions/test_normal_mean_variance.jl -@@ -3,6 +3,8 @@ module NormalMeanVarianceTest - using Test - using ReactiveMP - -+using LinearAlgebra: I -+ - @testset "NormalMeanVariance" begin - @testset "Constructor" begin - @test NormalMeanVariance <: NormalDistributionsFamily -@@ -20,6 +22,13 @@ using ReactiveMP - @test NormalMeanVariance(1.0f0, 2) == NormalMeanVariance{Float32}(1.0f0, 2.0f0) - @test NormalMeanVariance(1.0f0, 2.0) == NormalMeanVariance{Float64}(1.0, 2.0) - -+ # uniformscaling -+ @test NormalMeanVariance(2, I) == NormalMeanVariance(2, 1) -+ @test NormalMeanVariance(2, 6 * I) == NormalMeanVariance(2, 6) -+ @test NormalMeanVariance(2.0, I) == NormalMeanVariance(2.0, 1.0) -+ @test NormalMeanVariance(2.0, 6 * I) == NormalMeanVariance(2.0, 6.0) -+ @test NormalMeanVariance(2, 6.0 * I) == NormalMeanVariance(2.0, 6.0) -+ - @test eltype(NormalMeanVariance()) === Float64 - @test eltype(NormalMeanVariance(0.0)) === Float64 - @test eltype(NormalMeanVariance(0.0, 1.0)) === Float64 -diff --git a/test/distributions/test_normal_weighted_mean_precision.jl b/test/distributions/test_normal_weighted_mean_precision.jl -index ace5ebfd..5a90d97e 100644 ---- a/test/distributions/test_normal_weighted_mean_precision.jl -+++ b/test/distributions/test_normal_weighted_mean_precision.jl -@@ -3,6 +3,8 @@ module NormalWeightedMeanPrecisionTest - using Test - using ReactiveMP - -+using LinearAlgebra: I -+ - @testset "NormalWeightedMeanPrecision" begin - @testset "Constructor" begin - @test NormalWeightedMeanPrecision <: NormalDistributionsFamily -@@ -19,6 +21,13 @@ using ReactiveMP - @test NormalWeightedMeanPrecision(1.0f0, 2.0f0) == NormalWeightedMeanPrecision{Float32}(1.0f0, 2.0f0) - @test NormalWeightedMeanPrecision(1.0f0, 2.0) == NormalWeightedMeanPrecision{Float64}(1.0, 2.0) - -+ # uniformscaling -+ @test NormalWeightedMeanPrecision(2, I) == NormalWeightedMeanPrecision(2, 1) -+ @test NormalWeightedMeanPrecision(2, 6 * I) == NormalWeightedMeanPrecision(2, 6) -+ @test 
NormalWeightedMeanPrecision(2.0, I) == NormalWeightedMeanPrecision(2.0, 1.0) -+ @test NormalWeightedMeanPrecision(2.0, 6 * I) == NormalWeightedMeanPrecision(2.0, 6.0) -+ @test NormalWeightedMeanPrecision(2, 6.0 * I) == NormalWeightedMeanPrecision(2.0, 6.0) -+ - @test eltype(NormalWeightedMeanPrecision()) === Float64 - @test eltype(NormalWeightedMeanPrecision(0.0)) === Float64 - @test eltype(NormalWeightedMeanPrecision(0.0, 1.0)) === Float64 -diff --git a/test/distributions/test_pointmass.jl b/test/distributions/test_pointmass.jl -index 8f45a481..1a2e926b 100644 ---- a/test/distributions/test_pointmass.jl -+++ b/test/distributions/test_pointmass.jl -@@ -5,6 +5,7 @@ using ReactiveMP - using Distributions - using Random - using SpecialFunctions -+using LinearAlgebra: UniformScaling, I - - import ReactiveMP: CountingReal, tiny, huge - import ReactiveMP.MacroHelpers: @test_inferred -@@ -163,6 +164,47 @@ import ReactiveMP: xtlog, mirrorlog - @test @test_inferred(AbstractMatrix{T}, mean(loggamma, dist)) == loggamma.(matrix) - end - end -+ -+ @testset "UniformScaling-based PointMass" begin -+ for T in (Float16, Float32, Float64, BigFloat) -+ matrix = convert(T, 5) * I -+ dist = PointMass(matrix) -+ -+ @test variate_form(dist) === Matrixvariate -+ @test dist[2, 1] == zero(T) -+ @test dist[3, 1] == zero(T) -+ @test dist[3, 3] === matrix[3, 3] -+ -+ @test pdf(dist, matrix) == one(T) -+ @test pdf(dist, matrix + convert(T, tiny) * I) == zero(T) -+ @test pdf(dist, matrix - convert(T, tiny) * I) == zero(T) -+ -+ @test logpdf(dist, matrix) == zero(T) -+ @test logpdf(dist, matrix + convert(T, tiny) * I) == convert(T, -Inf) -+ @test logpdf(dist, matrix - convert(T, tiny) * I) == convert(T, -Inf) -+ -+ @test_throws MethodError insupport(dist, one(T)) -+ @test_throws MethodError insupport(dist, ones(T, 2)) -+ @test_throws MethodError pdf(dist, one(T)) -+ @test_throws MethodError pdf(dist, ones(T, 2)) -+ @test_throws MethodError logpdf(dist, one(T)) -+ @test_throws MethodError logpdf(dist, ones(T, 2)) -+ -+ @test (@inferred entropy(dist)) == CountingReal(eltype(dist), -1) -+ -+ @test mean(dist) == matrix -+ @test mode(dist) == matrix -+ @test var(dist) == zero(T) * I -+ @test std(dist) == zero(T) * I -+ -+ @test_throws ErrorException cov(dist) -+ @test_throws ErrorException precision(dist) -+ -+ @test_throws ErrorException probvec(dist) -+ @test mean(inv, dist) ≈ inv(matrix) -+ @test mean(cholinv, dist) ≈ inv(matrix) -+ end -+ end - end - - end -diff --git a/test/rules/bernoulli/test_marginals.jl b/test/rules/bernoulli/test_marginals.jl -index da6be1e2..d1b835f8 100644 ---- a/test/rules/bernoulli/test_marginals.jl -+++ b/test/rules/bernoulli/test_marginals.jl -@@ -13,5 +13,12 @@ import ReactiveMP: @test_marginalrules - (input = (m_out = PointMass(0.0), m_p = Beta(1.0, 2.0)), output = (out = PointMass(0.0), p = Beta(1.0, 3.0))) - ] - end -+ @testset "out_p: (m_out::Bernoulli, m_p::PointMass)" begin -+ @test_marginalrules [with_float_conversions = true] Bernoulli(:out_p) [ -+ (input = (m_out = Bernoulli(0.8), m_p = PointMass(1.0)), output = (out = Bernoulli(1.0), p = PointMass(1.0))), -+ (input = (m_out = Bernoulli(0.2), m_p = PointMass(1.0)), output = (out = Bernoulli(1.0), p = PointMass(1.0))), -+ (input = (m_out = Bernoulli(0.2), m_p = PointMass(0.0)), output = (out = Bernoulli(0.0), p = PointMass(0.0))) -+ ] -+ end - end - end -diff --git a/test/rules/bernoulli/test_p.jl b/test/rules/bernoulli/test_p.jl -index c6c34c34..f5b08bf7 100644 ---- a/test/rules/bernoulli/test_p.jl -+++ 
b/test/rules/bernoulli/test_p.jl -@@ -19,7 +19,7 @@ import ReactiveMP: @test_rules - end - - @testset "Variational Message Passing: (q_out::DiscreteNonParametric)" begin -- # `with_falot_conversions = false` here is because apparently -+ # `with_float_conversions = false` here is because apparently - # BigFloat(0.7) + BigFloat(0.3) != BigFloat(1.0) - @test_rules [with_float_conversions = false] Bernoulli(:p, Marginalisation) [ - (input = (q_out = Categorical([0.0, 1.0]),), output = Beta(2.0, 1.0)), (input = (q_out = Categorical([0.7, 0.3]),), output = Beta(13 / 10, 17 / 10)) -diff --git a/test/rules/categorical/test_marginals.jl b/test/rules/categorical/test_marginals.jl -new file mode 100644 -index 00000000..43363b5c ---- /dev/null -+++ b/test/rules/categorical/test_marginals.jl -@@ -0,0 +1,25 @@ -+module RulesCategoricalMarginalsTest -+ -+using Test -+using ReactiveMP -+using Random -+using LinearAlgebra -+import ReactiveMP: @test_marginalrules -+ -+@testset "marginalrules:Categorical" begin -+ @testset "out_p: (m_out::PointMass, m_p::Dirichlet)" begin -+ @test_marginalrules [with_float_conversions = true] Categorical(:out_p) [ -+ (input = (m_out = PointMass([0.0, 1.0]), m_p = Dirichlet([2.0, 1.0])), output = (out = PointMass([0.0, 1.0]), p = Dirichlet([2.0, 2.0]))), -+ (input = (m_out = PointMass([0.0, 1.0]), m_p = Dirichlet([4.0, 2.0])), output = (out = PointMass([0.0, 1.0]), p = Dirichlet([4.0, 3.0]))), -+ (input = (m_out = PointMass([1.0, 0.0]), m_p = Dirichlet([1.0, 2.0])), output = (out = PointMass([1.0, 0.0]), p = Dirichlet([2.0, 2.0]))) -+ ] -+ end -+ @testset "out_p: (m_out::Categorical, m_p::PointMass)" begin -+ @test_marginalrules [with_float_conversions = false] Categorical(:out_p) [ -+ (input = (m_out = Categorical([0.2, 0.8]), m_p = PointMass([0.0, 1.0])), output = (out = Categorical(normalize([tiny, 0.8], 1)), p = PointMass([0.0, 1.0]))), -+ (input = (m_out = Categorical([0.8, 0.2]), m_p = PointMass([0.0, 1.0])), output = (out = Categorical(normalize([tiny, 0.2], 1)), p = PointMass([0.0, 1.0]))), -+ (input = (m_out = Categorical([0.8, 0.2]), m_p = PointMass([1.0, 0.0])), output = (out = Categorical(normalize([0.8, tiny], 1)), p = PointMass([1.0, 0.0]))) -+ ] -+ end -+end -+end -diff --git a/test/rules/categorical/test_out.jl b/test/rules/categorical/test_out.jl -new file mode 100644 -index 00000000..929ffbc2 ---- /dev/null -+++ b/test/rules/categorical/test_out.jl -@@ -0,0 +1,27 @@ -+module RulesCategoricalOutTest -+ -+using Test -+using ReactiveMP -+using Random -+import ReactiveMP: @test_rules -+ -+@testset "rules:Categorical:out" begin -+ @testset "Belief Propagation: (m_p::PointMass)" begin -+ @test_rules [with_float_conversions = false] Categorical(:out, Marginalisation) [ -+ (input = (m_p = PointMass([0.0, 1.0]),), output = Categorical([0.0, 1.0])), (input = (m_p = PointMass([0.8, 0.2]),), output = Categorical([0.8, 0.2])) -+ ] -+ end -+ -+ @testset "Variational Message Passing: (q_p::PointMass)" begin -+ @test_rules [with_float_conversions = false] Categorical(:out, Marginalisation) [ -+ (input = (q_p = PointMass([0.0, 1.0]),), output = Categorical([0.0, 1.0])), (input = (q_p = PointMass([0.7, 0.3]),), output = Categorical([0.7, 0.3])) -+ ] -+ end -+ -+ @testset "Variational Message Passing: (q_p::Dirichlet)" begin -+ @test_rules [with_float_conversions = false] Categorical(:out, Marginalisation) [ -+ (input = (q_p = Dirichlet([1.0, 1.0]),), output = Categorical([0.5, 0.5])), (input = (q_p = Dirichlet([0.2, 0.2]),), output = Categorical([0.5, 0.5])) -+ ] -+ end 
-+end -+end -diff --git a/test/rules/categorical/test_p.jl b/test/rules/categorical/test_p.jl -new file mode 100644 -index 00000000..55e396ca ---- /dev/null -+++ b/test/rules/categorical/test_p.jl -@@ -0,0 +1,21 @@ -+module RulesCategoricalPTest -+ -+using Test -+using ReactiveMP -+using Random -+import ReactiveMP: @test_rules -+ -+@testset "rules:Categorical:p" begin -+ @testset "Variational Message Passing: (q_out::PointMass)" begin -+ @test_rules [with_float_conversions = true] Categorical(:p, Marginalisation) [ -+ (input = (q_out = PointMass([0.0, 1.0]),), output = Dirichlet([1.0, 2.0])), (input = (q_out = PointMass([0.8, 0.2]),), output = Dirichlet([9 / 5, 12 / 10])) -+ ] -+ end -+ -+ @testset "Variational Message Passing: (q_out::Categorical)" begin -+ @test_rules [with_float_conversions = false] Categorical(:p, Marginalisation) [ -+ (input = (q_out = Categorical([0.0, 1.0]),), output = Dirichlet([1.0, 2.0])), (input = (q_out = Categorical([0.7, 0.3]),), output = Dirichlet([17 / 10, 13 / 10])) -+ ] -+ end -+end -+end -diff --git a/test/rules/dirichlet/test_marginals.jl b/test/rules/dirichlet/test_marginals.jl -new file mode 100644 -index 00000000..51682227 ---- /dev/null -+++ b/test/rules/dirichlet/test_marginals.jl -@@ -0,0 +1,17 @@ -+module RulesDirichletMarginalsTest -+ -+using Test -+using ReactiveMP -+using Random -+import ReactiveMP: @test_marginalrules -+ -+@testset "marginalrules:Dirichlet" begin -+ @testset "out_a: (m_out::Dirichlet, m_a::PointMass)" begin -+ @test_marginalrules [with_float_conversions = true] Dirichlet(:out_a) [ -+ (input = (m_out = Dirichlet([1.0, 2.0]), m_a = PointMass([0.2, 1.0])), output = (out = Dirichlet([0.2, 2.0]), a = PointMass([0.2, 1.0]))), -+ (input = (m_out = Dirichlet([2.0, 2.0]), m_a = PointMass([2.0, 0.5])), output = (out = Dirichlet([3.0, 1.5]), a = PointMass([2.0, 0.5]))), -+ (input = (m_out = Dirichlet([2.0, 3.0]), m_a = PointMass([3.0, 1.0])), output = (out = Dirichlet([4.0, 3.0]), a = PointMass([3.0, 1.0]))) -+ ] -+ end -+end -+end -diff --git a/test/rules/dirichlet/test_out.jl b/test/rules/dirichlet/test_out.jl -new file mode 100644 -index 00000000..874a5885 ---- /dev/null -+++ b/test/rules/dirichlet/test_out.jl -@@ -0,0 +1,25 @@ -+module RulesDirichletOutTest -+ -+using Test -+using ReactiveMP -+using Random -+import ReactiveMP: @test_rules -+ -+@testset "rules:Dirichlet:out" begin -+ @testset "Belief Propagation: (m_a::PointMass)" begin -+ @test_rules [with_float_conversions = true] Dirichlet(:out, Marginalisation) [ -+ (input = (m_a = PointMass([0.2, 1.0]),), output = Dirichlet([0.2, 1.0])), -+ (input = (m_a = PointMass([2.0, 0.5]),), output = Dirichlet([2.0, 0.5])), -+ (input = (m_a = PointMass([3.0, 1.0]),), output = Dirichlet([3.0, 1.0])) -+ ] -+ end -+ -+ @testset "Variational Message Passing: (q_a::PointMass)" begin -+ @test_rules [with_float_conversions = true] Dirichlet(:out, Marginalisation) [ -+ (input = (q_a = PointMass([0.2, 1.0]),), output = Dirichlet([0.2, 1.0])), -+ (input = (q_a = PointMass([2.0, 0.5]),), output = Dirichlet([2.0, 0.5])), -+ (input = (q_a = PointMass([3.0, 1.0]),), output = Dirichlet([3.0, 1.0])) -+ ] -+ end -+end -+end -diff --git a/test/runtests.jl b/test/runtests.jl -index 2e50971e..0c2fa1e2 100644 ---- a/test/runtests.jl -+++ b/test/runtests.jl -@@ -312,6 +312,10 @@ end - addtests(testrunner, "rules/beta/test_out.jl") - addtests(testrunner, "rules/beta/test_marginals.jl") - -+ addtests(testrunner, "rules/categorical/test_out.jl") -+ addtests(testrunner, "rules/categorical/test_p.jl") -+ 
addtests(testrunner, "rules/categorical/test_marginals.jl") -+ - addtests(testrunner, "rules/delta/unscented/test_out.jl") - addtests(testrunner, "rules/delta/unscented/test_in.jl") - addtests(testrunner, "rules/delta/unscented/test_marginals.jl") -@@ -324,6 +328,9 @@ end - addtests(testrunner, "rules/delta/cvi/test_marginals.jl") - addtests(testrunner, "rules/delta/cvi/test_out.jl") - -+ addtests(testrunner, "rules/dirichlet/test_marginals.jl") -+ addtests(testrunner, "rules/dirichlet/test_out.jl") -+ - addtests(testrunner, "rules/dot_product/test_out.jl") - addtests(testrunner, "rules/dot_product/test_in1.jl") - addtests(testrunner, "rules/dot_product/test_in2.jl") -diff --git a/test/variables/test_constant.jl b/test/variables/test_constant.jl -index 8453b723..0e349414 100644 ---- a/test/variables/test_constant.jl -+++ b/test/variables/test_constant.jl -@@ -4,13 +4,15 @@ using Test - using ReactiveMP - using Rocket - -+using LinearAlgebra: I -+ - import ReactiveMP: collection_type, VariableIndividual, VariableVector, VariableArray, linear_index - import ReactiveMP: getconst, proxy_variables - import ReactiveMP: israndom, isproxy - - @testset "ConstVariable" begin - @testset "Simple creation" begin -- for sym in (:x, :y, :z), value in (1.0, 1.0, "asd", [1.0, 1.0], [1.0 0.0; 0.0 1.0], (x) -> 1) -+ for sym in (:x, :y, :z), value in (1.0, 1.0, "asd", I, 0.3 * I, [1.0, 1.0], [1.0 0.0; 0.0 1.0], (x) -> 1) - v = constvar(sym, value) - - @test !israndom(v) -diff --git a/test/variables/test_data.jl b/test/variables/test_data.jl -index 683cf1a1..42b517a5 100644 ---- a/test/variables/test_data.jl -+++ b/test/variables/test_data.jl -@@ -4,10 +4,10 @@ using Test - using ReactiveMP - using Rocket - --import ReactiveMP: DataVariableCreationOptions -+import ReactiveMP: DataVariableCreationOptions, MessageObservable - import ReactiveMP: collection_type, VariableIndividual, VariableVector, VariableArray, linear_index - import ReactiveMP: getconst, proxy_variables --import ReactiveMP: israndom, isproxy, allows_missings -+import ReactiveMP: israndom, isproxy, isused, isconnected, setmessagein!, allows_missings - - @testset "DataVariable" begin - @testset "Simple creation" begin -@@ -44,10 +44,20 @@ import ReactiveMP: israndom, isproxy, allows_missings - @test !israndom(variable) - @test eltype(variable) === T - @test name(variable) === sym -+ @test allows_missings(variable) === allow_missings - @test collection_type(variable) isa VariableIndividual - @test proxy_variables(variable) === nothing - @test !isproxy(variable) -- @test allows_missings(variable) === allow_missings -+ @test !isused(variable) -+ @test !isconnected(variable) -+ -+ setmessagein!(variable, 1, MessageObservable()) -+ -+ @test isused(variable) -+ @test isconnected(variable) -+ -+ # `100` could a valid index, but messages should be initialized in order, previous was `1` -+ @test_throws ErrorException setmessagein!(variable, 100, MessageObservable()) - end - - for sym in (:x, :y, :z), T in (Float64, Int64, Vector{Float64}), n in (10, 20), allow_missings in (true, false) -@@ -59,22 +69,24 @@ import ReactiveMP: israndom, isproxy, allows_missings - @test variables isa Vector - @test all(v -> !israndom(v), variables) - @test all(v -> name(v) === sym, variables) -+ @test all(v -> allows_missings(v) === allow_missings, variables) - @test all(v -> collection_type(v) isa VariableVector, variables) - @test all(t -> linear_index(collection_type(t[2])) === t[1], enumerate(variables)) - @test all(v -> eltype(v) === T, variables) - @test 
!isproxy(variables) - @test all(v -> !isproxy(v), variables) -+ @test all(v -> !isused(v), variables) -+ @test all(v -> !isconnected(v), variables) - @test test_updates(variables, T, (n,)) - -- @test all(v -> allows_missings(v) === allow_missings, variables) -- if allow_missings -- test_updates(variables, Missing, (n,)) -- end -+ foreach(v -> setmessagein!(v, 1, MessageObservable()), variables) -+ -+ @test all(v -> isused(v), variables) -+ @test all(v -> isconnected(v), variables) - end - -- for sym in (:x, :y, :z), T in (Float64, Int64, Vector{Float64}), l in (10, 20), r in (10, 20), allow_missings in (true, false) -- options = DataVariableCreationOptions(T, nothing, Val(allow_missings)) -- for variables in (datavar(options, sym, T, l, r), datavar(options, sym, T, (l, r))) -+ for sym in (:x, :y, :z), T in (Float64, Int64, Vector{Float64}), l in (10, 20), r in (10, 20) -+ for variables in (datavar(sym, T, l, r), datavar(sym, T, (l, r))) - @test !israndom(variables) - @test size(variables) === (l, r) - @test length(variables) === l * r -@@ -86,12 +98,13 @@ import ReactiveMP: israndom, isproxy, allows_missings - @test all(v -> eltype(v) === T, variables) - @test !isproxy(variables) - @test all(v -> !isproxy(v), variables) -+ @test all(v -> !isused(v), variables) - @test test_updates(variables, T, (l, r)) - -- @test all(v -> allows_missings(v) === allow_missings, variables) -- if allow_missings -- test_updates(variables, Missing, (l, r)) -- end -+ foreach(v -> setmessagein!(v, 1, MessageObservable()), variables) -+ -+ @test all(v -> isused(v), variables) -+ @test all(v -> isconnected(v), variables) - end - end - end From e0edea1bb13b778b871d4a969dd7560b04f04a88 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 8 Sep 2023 14:36:30 +0200 Subject: [PATCH 43/48] Remove MV autoregressive node --- src/nodes/mv_autoregressive.jl | 149 ----------------------- src/rules/mv_autoregressive/a.jl | 41 ------- src/rules/mv_autoregressive/lambda.jl | 47 ------- src/rules/mv_autoregressive/marginals.jl | 41 ------- src/rules/mv_autoregressive/x.jl | 44 ------- src/rules/mv_autoregressive/y.jl | 31 ----- src/rules/prototypes.jl | 6 - 7 files changed, 359 deletions(-) delete mode 100644 src/nodes/mv_autoregressive.jl delete mode 100644 src/rules/mv_autoregressive/a.jl delete mode 100644 src/rules/mv_autoregressive/lambda.jl delete mode 100644 src/rules/mv_autoregressive/marginals.jl delete mode 100644 src/rules/mv_autoregressive/x.jl delete mode 100644 src/rules/mv_autoregressive/y.jl diff --git a/src/nodes/mv_autoregressive.jl b/src/nodes/mv_autoregressive.jl deleted file mode 100644 index f0c969362..000000000 --- a/src/nodes/mv_autoregressive.jl +++ /dev/null @@ -1,149 +0,0 @@ -export MAR, MvAutoregressive, MARMeta, mar_transition, mar_shift - -import LazyArrays, BlockArrays -import StatsFuns: log2π - -struct MAR end - -const MvAutoregressive = MAR - -struct MARMeta - order :: Int # order (lag) of MAR - ds :: Int # dimensionality of MAR process, i.e., the number of correlated AR processes - Fs :: Vector{<:AbstractMatrix} # masks - es :: Vector{<:AbstractVector} # unit vectors - - function MARMeta(order, ds = 2) - @assert ds >= 2 "ds parameter should be > 1. 
Use AR node if ds = 1" - Fs = [mask_mar(order, ds, i) for i in 1:ds] - es = [uvector(order * ds, i) for i in 1:ds] - return new(order, ds, Fs, es) - end -end - -getorder(meta::MARMeta) = meta.order -getdimensionality(meta::MARMeta) = meta.ds -getmasks(meta::MARMeta) = meta.Fs -getunits(meta::MARMeta) = meta.es - -@node MAR Stochastic [y, x, a, Λ] - -default_meta(::Type{MAR}) = error("MvAutoregressive node requires meta flag explicitly specified") - -@average_energy MAR (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Wishart, meta::MARMeta) = begin - ma, Va = mean_cov(q_a) - myx, Vyx = mean_cov(q_y_x) - mΛ = mean(q_Λ) - - order, ds = getorder(meta), getdimensionality(meta) - Fs, es = getmasks(meta), getunits(meta) - - dim = order * ds - F = Multivariate - - n = div(ndims(q_y_x), 2) - - ma, Va = mean_cov(q_a) - mA = mar_companion_matrix(ma, meta)[1:ds, 1:dim] - - mx, Vx = ar_slice(F, myx, (dim + 1):(2dim)), ar_slice(F, Vyx, (dim + 1):(2dim), (dim + 1):(2dim)) - my1, Vy1 = myx[1:ds], Vyx[1:ds, 1:ds] - Vy1x = ar_slice(F, Vyx, 1:ds, (dim + 1):(2dim)) - - g₁ = my1' * mΛ * my1 + tr(Vy1 * mΛ) - g₂ = mx' * mA' * mΛ * my1 + tr(Vy1x * mA' * mΛ) - g₃ = g₂ - G = sum(sum(es[i]' * mΛ * es[j] * Fs[i] * (ma * ma' + Va) * Fs[j]' for i in 1:ds) for j in 1:ds) - g₄ = mx' * G * mx + tr(Vx * G) - AE = n / 2 * log2π - 0.5 * mean(logdet, q_Λ) + 0.5 * (g₁ - g₂ - g₃ + g₄) - - if order > 1 - AE += entropy(q_y_x) - idc = LazyArrays.Vcat(1:ds, (dim + 1):(2dim)) - myx_n = view(myx, idc) - Vyx_n = view(Vyx, idc, idc) - q_y_x = MvNormalMeanCovariance(myx_n, Vyx_n) - AE -= entropy(q_y_x) - end - - return AE -end - -@average_energy MAR ( - q_y::MultivariateNormalDistributionsFamily, q_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Wishart, meta::MARMeta -) = begin - ma, Va = mean_cov(q_a) - my, Vy = mean_cov(q_y) - mx, Vx = mean_cov(q_y) - mΛ = mean(q_Λ) - - order, ds = getorder(meta), getdimensionality(meta) - Fs, es = getmasks(meta), getunits(meta) - - dim = order * ds - F = Multivariate - - ma, Va = mean_cov(q_a) - mA = mar_companion_matrix(ma, meta)[1:ds, 1:dim] - - my1, Vy1 = my[1:ds], Vy[1:ds, 1:ds] - - g₁ = my1' * mΛ * my1 + tr(Vy1 * mΛ) - g₂ = -mx' * mA' * mΛ * my1 - g₃ = -g₂ - G = sum(sum(es[i]' * mΛ * es[j] * Fs[i] * (ma * ma' + Va) * Fs[j]' for i in 1:ds) for j in 1:ds) - g₄ = mx' * G * mx + tr(Vx * G) - AE = dim / 2 * log2π - 0.5 * mean(logdet, q_Λ) + 0.5 * (g₁ + g₂ + g₃ + g₄) - - if order > 1 - AE += entropy(q_y) - q_y = MvNormalMeanCovariance(my1, Vy1) - AE -= entropy(q_y) - end - - return AE -end - -# Helpers for AR rules -function mask_mar(order, dimension, index) - F = zeros(dimension * order, dimension * dimension * order) - - @inbounds for k in 1:order - start_col = (k - 1) * dimension^2 + (index - 1) * dimension + 1 - end_col = start_col + dimension - 1 - start_row = (k - 1) * dimension + 1 - end_row = start_row + dimension - 1 - F[start_row:end_row, start_col:end_col] = I(dimension) - end - - return F -end - -function mar_transition(order, Λ) - dim = size(Λ, 1) - W = diageye(dim * order) - W[1:dim, 1:dim] = Λ - return W -end - -function mar_shift(order, ds) - dim = order * ds - S = diageye(dim) - S = circshift(S, ds) - S[:, (end - ds + 1):end] .= 0 - return S -end - -function uvector(dim, pos = 1) - u = zeros(dim) - u[pos] = 1 - return dim == 1 ? 
u[pos] : u -end - -function mar_companion_matrix(a, meta::MARMeta) - order, ds = getorder(meta), getdimensionality(meta) - Fs, es = getmasks(meta), getunits(meta) - - L = mar_shift(order, ds) .+ sum(es[i] * a' * Fs[i]' for i in 1:ds) - return L -end \ No newline at end of file diff --git a/src/rules/mv_autoregressive/a.jl b/src/rules/mv_autoregressive/a.jl deleted file mode 100644 index 41b488646..000000000 --- a/src/rules/mv_autoregressive/a.jl +++ /dev/null @@ -1,41 +0,0 @@ - -@rule MAR(:a, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin - order, ds = getorder(meta), getdimensionality(meta) - Fs, es = getmasks(meta), getunits(meta) - dim = order * ds - - m, V = mean_cov(q_y_x) - - F = Multivariate - - my, Vy = ar_slice(F, m, 1:dim), ar_slice(F, V, 1:dim, 1:dim) - mx, Vx = ar_slice(F, m, (dim + 1):(2dim)), ar_slice(F, V, (dim + 1):(2dim), (dim + 1):(2dim)) - Vyx = ar_slice(F, V, 1:dim, (dim + 1):(2dim)) - - mΛ = mean(q_Λ) - mW = mar_transition(order, mΛ) - - # NOTE: prove that sum(Fs[i]'*((mx*mx'+Vx')*S')*mW*es[i] for i in 1:ds) == 0.0 - D = sum(sum(es[i]' * mW * es[j] * Fs[i]' * (mx * mx' + Vx) * Fs[j] for i in 1:ds) for j in 1:ds) - z = sum(Fs[i]' * (mx * my' + Vyx') * mW * es[i] for i in 1:ds) - - return MvNormalWeightedMeanPrecision(z, D) -end - -@rule MAR(:a, Marginalisation) (q_y::MultivariateNormalDistributionsFamily, q_x::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin - order, ds = getorder(meta), getdimensionality(meta) - Fs, es = getmasks(meta), getunits(meta) - dim = order * ds - - my, Vy = mean_cov(q_y) - mx, Vx = mean_cov(q_x) - mΛ = mean(q_Λ) - - mW = mar_transition(order, mΛ) - S = mar_shift(order, ds) - - D = sum(sum(es[j]' * mW * es[i] * Fs[i]' * (mx * mx' + Vx) * Fs[j] for i in 1:ds) for j in 1:ds) - z = sum(Fs[i]' * ((mx * mx' + Vx') * S' + mx * my') * mW * es[i] for i in 1:ds) - - return MvNormalWeightedMeanPrecision(z, D) -end diff --git a/src/rules/mv_autoregressive/lambda.jl b/src/rules/mv_autoregressive/lambda.jl deleted file mode 100644 index 5e859f315..000000000 --- a/src/rules/mv_autoregressive/lambda.jl +++ /dev/null @@ -1,47 +0,0 @@ -function compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, ds, Fs, es) - G₁ = (my * my' + Vy)[1:ds, 1:ds] - G₂ = ((my * mx' + Vyx) * mA')[1:ds, 1:ds] - G₃ = transpose(G₂) - Ex_xx = mx * mx' + Vx - G₅ = sum(sum(es[i] * ma' * Fs[i]'Ex_xx * Fs[j] * ma * es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] - G₆ = sum(sum(es[i] * tr(Fs[i]' * Ex_xx * Fs[j] * Va) * es[j]' for i in 1:ds) for j in 1:ds)[1:ds, 1:ds] - Δ = G₁ - G₂ - G₃ + G₅ + G₆ -end - -@rule MAR(:Λ, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, meta::MARMeta) = begin - order, ds = getorder(meta), getdimensionality(meta) - Fs, es = getmasks(meta), getunits(meta) - dim = order * ds - - F = Multivariate - - ma, Va = mean_cov(q_a) - - mA = mar_companion_matrix(ma, meta) - - m, V = mean_cov(q_y_x) - my, Vy = ar_slice(F, m, 1:dim), ar_slice(F, V, 1:dim, 1:dim) - mx, Vx = ar_slice(F, m, (dim + 1):(2dim)), ar_slice(F, V, (dim + 1):(2dim), (dim + 1):(2dim)) - Vyx = ar_slice(F, V, 1:dim, (dim + 1):(2dim)) - - Δ = compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, ds, Fs, es) - - return WishartMessage(ds + 2, Δ) -end - -@rule MAR(:Λ, Marginalisation) (q_y::MultivariateNormalDistributionsFamily, q_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, meta::MARMeta) = - begin - order, ds = getorder(meta), getdimensionality(meta) - F 
= Multivariate - dim = order * ds - - my, Vy = mean_cov(q_y) - mx, Vx = mean_cov(q_x) - ma, Va = mean_cov(q_a) - - mA = mar_companion_matrix(ma, meta) - - Δ = compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, ds, Fs, es) - - return WishartMessage(ds + 2, Δ) - end diff --git a/src/rules/mv_autoregressive/marginals.jl b/src/rules/mv_autoregressive/marginals.jl deleted file mode 100644 index 213a69061..000000000 --- a/src/rules/mv_autoregressive/marginals.jl +++ /dev/null @@ -1,41 +0,0 @@ - -@marginalrule MAR(:y_x) ( - m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta -) = begin - return ar_y_x_marginal(m_y, m_x, q_a, q_Λ, meta) -end - -function ar_y_x_marginal( - m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta -) - order, ds = getorder(meta), getdimensionality(meta) - Fs, es = getmasks(meta), getunits(meta) - - ma, Va = mean_cov(q_a) - mΛ = mean(q_Λ) - - mA = mar_companion_matrix(ma, meta) - mW = mar_transition(getorder(meta), mΛ) - - b_my, b_Vy = mean_cov(m_y) - f_mx, f_Vx = mean_cov(m_x) - - inv_b_Vy = cholinv(b_Vy) - inv_f_Vx = cholinv(f_Vx) - - Ξ = inv_f_Vx + sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds) - - W_11 = inv_b_Vy + mW - - # negate_inplace!(mW * mA) - W_12 = -(mW * mA) - - W_21 = -(mA' * mW) - - W_22 = Ξ + mA' * mW * mA - - W = [W_11 W_12; W_21 W_22] - ξ = [inv_b_Vy * b_my; inv_f_Vx * f_mx] - - return MvNormalWeightedMeanPrecision(ξ, W) -end diff --git a/src/rules/mv_autoregressive/x.jl b/src/rules/mv_autoregressive/x.jl deleted file mode 100644 index cafd80461..000000000 --- a/src/rules/mv_autoregressive/x.jl +++ /dev/null @@ -1,44 +0,0 @@ - -@rule MAR(:x, Marginalisation) (m_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin - ma, Va = mean_cov(q_a) - my, Vy = mean_cov(m_y) - - mΛ = mean(q_Λ) - - order, ds = getorder(meta), getdimensionality(meta) - Fs, es = getmasks(meta), getunits(meta) - dim = order * ds - - mA = mar_companion_matrix(ma, meta) - mW = mar_transition(getorder(meta), mΛ) - - Λ = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds) - - Σ₁ = Hermitian(pinv(mA) * (Vy) * pinv(mA') + pinv(mA' * mW * mA)) - - Ξ = (pinv(Σ₁) + Λ) - z = pinv(Σ₁) * pinv(mA) * my - - return MvNormalWeightedMeanPrecision(z, Ξ) -end - -@rule MAR(:x, Marginalisation) (q_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin - ma, Va = mean_cov(q_a) - my, Vy = mean_cov(q_y) - - mΛ = mean(q_Λ) - - order, ds = getorder(meta), getdimensionality(meta) - Fs, es = getmasks(meta), getunits(meta) - - mA = mar_companion_matrix(ma, meta) - mW = mar_transition(getorder(meta), mΛ) - - Λ = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds) - Λ₀ = Hermitian(mA' * mW * mA) - - Ξ = Λ₀ + Λ - z = Λ₀ * pinv(mA) * my - - return MvNormalWeightedMeanPrecision(z, Ξ) -end diff --git a/src/rules/mv_autoregressive/y.jl b/src/rules/mv_autoregressive/y.jl deleted file mode 100644 index 6eede850b..000000000 --- a/src/rules/mv_autoregressive/y.jl +++ /dev/null @@ -1,31 +0,0 @@ -@rule MAR(:y, Marginalisation) (m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin - ma, Va = mean_cov(q_a) - mx, Wx = mean_invcov(m_x) - - mΛ = 
diff --git a/src/rules/mv_autoregressive/y.jl b/src/rules/mv_autoregressive/y.jl
deleted file mode 100644
index 6eede850b..000000000
--- a/src/rules/mv_autoregressive/y.jl
+++ /dev/null
@@ -1,31 +0,0 @@
-@rule MAR(:y, Marginalisation) (m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin
-    ma, Va = mean_cov(q_a)
-    mx, Wx = mean_invcov(m_x)
-
-    mΛ = mean(q_Λ)
-
-    order, ds = getorder(meta), getdimensionality(meta)
-    Fs, es = getmasks(meta), getunits(meta)
-
-    mA = mar_companion_matrix(ma, meta)
-    mW = mar_transition(getorder(meta), mΛ)
-
-    Λ = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:ds) for j in 1:ds)
-
-    Ξ = Λ + Wx
-    z = Wx * mx
-
-    Vy = mA * inv(Ξ) * mA' + inv(mW)
-    my = mA * inv(Ξ) * z
-
-    return MvNormalMeanCovariance(my, Vy)
-end
-
-@rule MAR(:y, Marginalisation) (q_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::MARMeta) = begin
-    order, ds = getorder(meta), getdimensionality(meta)
-
-    mA = mar_companion_matrix(mean(q_a), meta)
-    mW = mar_transition(getorder(meta), mean(q_Λ))
-
-    return MvNormalMeanPrecision(mA * mean(q_x), mW)
-end
diff --git a/src/rules/prototypes.jl b/src/rules/prototypes.jl
index be761e95e..d4d20ddd4 100644
--- a/src/rules/prototypes.jl
+++ b/src/rules/prototypes.jl
@@ -116,12 +116,6 @@ include("autoregressive/theta.jl")
 include("autoregressive/gamma.jl")
 include("autoregressive/marginals.jl")
 
-include("mv_autoregressive/y.jl")
-include("mv_autoregressive/x.jl")
-include("mv_autoregressive/a.jl")
-include("mv_autoregressive/lambda.jl")
-include("mv_autoregressive/marginals.jl")
-
 include("probit/marginals.jl")
 include("probit/in.jl")
 include("probit/out.jl")

From 0643188e7f8dc72f302995cc8e6b5fd300652413 Mon Sep 17 00:00:00 2001
From: Albert Podusenko
Date: Tue, 12 Sep 2023 16:33:33 +0200
Subject: [PATCH 44/48] Remove mv autoregressive from ReactiveMP.jl

---
 src/ReactiveMP.jl | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/ReactiveMP.jl b/src/ReactiveMP.jl
index 404ea0556..ab389d85e 100644
--- a/src/ReactiveMP.jl
+++ b/src/ReactiveMP.jl
@@ -155,7 +155,6 @@ include("nodes/dot_product.jl")
 include("nodes/softdot.jl")
 include("nodes/transition.jl")
 include("nodes/autoregressive.jl")
-include("nodes/mv_autoregressive.jl")
 include("nodes/bifm.jl")
 include("nodes/bifm_helper.jl")
 include("nodes/probit.jl")

From a248d5b87889c177f09057278e490a9f7cdd4905 Mon Sep 17 00:00:00 2001
From: Albert Podusenko
Date: Tue, 12 Sep 2023 16:39:20 +0200
Subject: [PATCH 45/48] Remove unneeded exports

---
 src/nodes/autoregressive.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/nodes/autoregressive.jl b/src/nodes/autoregressive.jl
index 2d903972c..d96964bc5 100644
--- a/src/nodes/autoregressive.jl
+++ b/src/nodes/autoregressive.jl
@@ -1,4 +1,4 @@
-export AR, Autoregressive, ARsafe, ARunsafe, ARMeta, ar_unit, ar_slice
+export AR, Autoregressive, ARsafe, ARunsafe, ARMeta

 import LazyArrays
 import Distributions: VariateForm
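Dropping `ar_unit` and `ar_slice` from the export list only trims the public surface; both helpers remain defined and reachable under a qualified name. A hedged usage sketch, where the call shape mirrors the rule code above and assumes the internal signature is unchanged:

    using ReactiveMP
    import Distributions: Multivariate

    # `ar_slice` is no longer exported, so qualify it explicitly.
    # Slice the `x` half out of a stacked [y; x] mean vector, as the MAR rules did:
    m  = collect(1.0:6.0)
    mx = ReactiveMP.ar_slice(Multivariate, m, 4:6)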
"864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DomainIntegrals = "cc6bae93-f070-4015-88fd-838f9505a86c" @@ -39,6 +28,16 @@ TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04" TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" Unrolled = "9602ed7d-8fef-5bc8-8597-8f21381861e8" +[weakdeps] +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[extensions] +ReactiveMPOptimisersExt = "Optimisers" +ReactiveMPRequiresExt = "Requires" +ReactiveMPZygoteExt = "Zygote" + [compat] DataStructures = "0.17, 0.18" Distributions = "0.24, 0.25" @@ -74,9 +73,9 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" From 0c94a128378e5f498d2a9e4ed3ee0af3f14f21c7 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Mon, 18 Sep 2023 11:48:19 +0200 Subject: [PATCH 47/48] Update src/variables/data.jl Co-authored-by: Bagaev Dmitry --- src/variables/data.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/variables/data.jl b/src/variables/data.jl index 6600cda8f..0f38d5364 100644 --- a/src/variables/data.jl +++ b/src/variables/data.jl @@ -194,7 +194,6 @@ _getprediction(datavar::DataVariable) = datavar.prediction _setprediction!(datavar::DataVariable, observable) = connect!(_getprediction(datavar), observable) _makeprediction(datavar::DataVariable) = collectLatest(AbstractMessage, Marginal, datavar.input_messages, marginal_prod_fn(datavar)) -# options here must implement at least `Rocket.getscheduler` function activate!(datavar::DataVariable, options) _setprediction!(datavar, _makeprediction(datavar)) From ad7b165c914acd0d71fd0b6498642772aca0b45a Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Mon, 18 Sep 2023 15:16:37 +0200 Subject: [PATCH 48/48] fix warning for predicted datavars --- src/constraints/specifications/constraints.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/constraints/specifications/constraints.jl b/src/constraints/specifications/constraints.jl index 1f5517b4f..b83ec6c5c 100644 --- a/src/constraints/specifications/constraints.jl +++ b/src/constraints/specifications/constraints.jl @@ -104,7 +104,10 @@ function activate!(constraints::ConstraintsSpecification, nodes::FactorNodesColl foreach(constraints.factorisation) do spec specnames = getnames(spec) foreach(specnames) do specname - if warn && (hasdatavar(variables, specname) || hasconstvar(variables, specname)) + if hasdatavar(variables, specname) && allows_missings(variables[specname]) + # skip, because it is fine to have a datavar in the factorization constraint, which allows missings + nothing + elseif warn && (hasdatavar(variables, specname) || hasconstvar(variables, specname)) @warn "Constraints specification has factorisation constraint for `q($(join(specnames, ", ")))`, but `$(specname)` is not a random variable. 
From ad7b165c914acd0d71fd0b6498642772aca0b45a Mon Sep 17 00:00:00 2001
From: Bagaev Dmitry
Date: Mon, 18 Sep 2023 15:16:37 +0200
Subject: [PATCH 48/48] Fix warning for predicted datavars

---
 src/constraints/specifications/constraints.jl | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/constraints/specifications/constraints.jl b/src/constraints/specifications/constraints.jl
index 1f5517b4f..b83ec6c5c 100644
--- a/src/constraints/specifications/constraints.jl
+++ b/src/constraints/specifications/constraints.jl
@@ -104,7 +104,10 @@ function activate!(constraints::ConstraintsSpecification, nodes::FactorNodesColl
     foreach(constraints.factorisation) do spec
         specnames = getnames(spec)
         foreach(specnames) do specname
-            if warn && (hasdatavar(variables, specname) || hasconstvar(variables, specname))
+            if hasdatavar(variables, specname) && allows_missings(variables[specname])
+                # skip, it is fine to have a datavar that allows missings in the factorisation constraint
+                nothing
+            elseif warn && (hasdatavar(variables, specname) || hasconstvar(variables, specname))
                 @warn "Constraints specification has factorisation constraint for `q($(join(specnames, ", ")))`, but `$(specname)` is not a random variable. Data variables and constants in the model are forced to be factorized by default such that `q($(join(specnames, ", "))) = q($(specname))q(...)`. Use `warn = false` option during constraints specification to suppress this warning."
             elseif warn && !hasrandomvar(variables, specname)
                 @warn "Constraints specification has factorisation constraint for `q($(join(specnames, ", ")))`, but variables collection has no random variable named `$(specname)`. Use `warn = false` option during constraints specification to suppress this warning."
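The net effect of the final patch: a datavar that admits `missing` values (the prediction case) may now appear inside a factorisation constraint without triggering the first warning. A standalone restatement of the decision chain, written purely for illustration and not taken from the ReactiveMP internals:

    # Mirrors the if/elseif chain added above, with plain booleans standing in for
    # the `hasdatavar`/`hasconstvar`/`hasrandomvar` lookups on the variables collection.
    function warns_for(specname::Symbol; is_datavar::Bool, allows_missings::Bool,
                       is_constvar::Bool, is_randomvar::Bool, warn::Bool = true)
        if is_datavar && allows_missings
            return false # predicted datavars are allowed in q(...) silently
        elseif warn && (is_datavar || is_constvar)
            return true  # data variables and constants are factorised out by default
        elseif warn && !is_randomvar
            return true  # unknown variable name in the constraint
        end
        return false
    end

    warns_for(:y; is_datavar = true, allows_missings = true,
              is_constvar = false, is_randomvar = false) # -> false, no warning emitted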