From 68efb78a62b39e0b375b857b2698394ab833fb2c Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Mon, 25 Sep 2023 14:27:59 +0200 Subject: [PATCH 1/5] use MatrixCorrectionTools library --- Project.toml | 2 + src/ReactiveMP.jl | 2 +- src/helpers/algebra/correction.jl | 102 ----------------------- src/nodes/dot_product.jl | 4 +- src/nodes/multiplication.jl | 4 +- src/rules/dot_product/in1.jl | 2 +- src/rules/dot_product/in2.jl | 2 +- src/rules/dot_product/marginals.jl | 4 +- src/rules/dot_product/out.jl | 4 +- src/rules/multiplication/A.jl | 18 ++-- src/rules/multiplication/in.jl | 18 ++-- src/rules/multiplication/marginals.jl | 4 +- src/rules/multiplication/out.jl | 22 ++--- test/algebra/test_correction.jl | 62 -------------- test/rules/dot_product/test_in1.jl | 37 +++++--- test/rules/dot_product/test_in2.jl | 36 +++++--- test/rules/dot_product/test_marginals.jl | 47 ++++++----- test/rules/dot_product/test_out.jl | 25 +++--- test/rules/multiplication/test_A.jl | 8 +- test/rules/multiplication/test_in.jl | 8 +- test/rules/multiplication/test_out.jl | 8 +- test/runtests.jl | 1 - 22 files changed, 146 insertions(+), 274 deletions(-) delete mode 100644 src/helpers/algebra/correction.jl delete mode 100644 test/algebra/test_correction.jl diff --git a/Project.toml b/Project.toml index 6299e5d70..d31cb00ba 100644 --- a/Project.toml +++ b/Project.toml @@ -15,6 +15,7 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +MatrixCorrectionTools = "41f81499-25de-46de-b591-c3cfc21e9eaf" Optim = "429524aa-4258-5aef-a3af-852621145aeb" PositiveFactorizations = "85a6dd25-e78a-55b7-8502-1745935b8125" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -49,6 +50,7 @@ HCubature = "1.0.0" LazyArrays = "0.21, 0.22, 1" LoopVectorization = "0.12" MacroTools = "0.5" +MatrixCorrectionTools = "1.2.0" Optim = "1.0.0" Optimisers = "0.2" PositiveFactorizations = "0.2" diff --git a/src/ReactiveMP.jl b/src/ReactiveMP.jl index ab389d85e..0d8a6f644 100644 --- a/src/ReactiveMP.jl +++ b/src/ReactiveMP.jl @@ -3,6 +3,7 @@ module ReactiveMP # List global dependencies here using TinyHugeNumbers +using MatrixCorrectionTools: AbstractCorrectionStrategy, correction! # Reexport `tiny` and `huge` from the `TinyHugeNumbers` export tiny, huge @@ -16,7 +17,6 @@ include("score/counting.jl") include("helpers/algebra/cholesky.jl") include("helpers/algebra/companion_matrix.jl") -include("helpers/algebra/correction.jl") include("helpers/algebra/common.jl") include("helpers/algebra/permutation_matrix.jl") include("helpers/algebra/standard_basis_vector.jl") diff --git a/src/helpers/algebra/correction.jl b/src/helpers/algebra/correction.jl deleted file mode 100644 index bf3b272fd..000000000 --- a/src/helpers/algebra/correction.jl +++ /dev/null @@ -1,102 +0,0 @@ -export NoCorrection, TinyCorrection, FixedCorrection, ClampEigenValuesCorrection - -abstract type AbstractCorrection end - -# Correction regularization terms for matrices - -""" - correction!(strategy, matrix) - correction!(strategy, real) - -Modifies the `matrix` with a specified correction strategy. Matrix must be squared. -Also supports real values, with the same strategies. - -See also: [`NoCorrection`](@ref), [`TinyCorrection`](@ref) -""" -function correction! end - -""" - NoCorrection - -One of the correction strategies for `correction!` function. Does not modify matrix and returns an original one. - -See also: [`correction!`](@ref), [`TinyCorrection`](@ref) -""" -struct NoCorrection <: AbstractCorrection end - -correction!(::NoCorrection, value::Real) = value -correction!(::NoCorrection, matrix::AbstractMatrix) = matrix -correction!(::Nothing, something) = correction!(NoCorrection(), something) - -""" - TinyCorrection - -One of the correction strategies for `correction!` function. Adds `ReactiveMP.tiny` term to the `matrix`'s diagonal. - -See also: [`correction!`](@ref), [`NoCorrection`](@ref), [`FixedCorrection`](@ref), [`ClampEigenValuesCorrection`](@ref) -""" -struct TinyCorrection <: AbstractCorrection end - -correction!(::TinyCorrection, value::Real) = clamp(value, tiny, typemax(value)) - -function correction!(::TinyCorrection, matrix::AbstractMatrix) - s = size(matrix) - @assert length(s) == 2 && s[1] === s[2] - for i in 1:s[1] - @inbounds matrix[i, i] += tiny - end - return matrix -end - -""" - FixedCorrection - -One of the correction strategies for `correction!` function. Adds fixed `v` term to the `matrix`'s diagonal. - -# Arguments -- `v`: fixed value to add to the matrix diagonal - -See also: [`correction!`](@ref), [`NoCorrection`](@ref), [`TinyCorrection`](@ref), [`ClampEigenValuesCorrection`](@ref) -""" -struct FixedCorrection{T} <: AbstractCorrection - v::T -end - -correction!(correction::FixedCorrection, value::Real) = clamp(value, correction.v, Inf) - -function correction!(correction::FixedCorrection, matrix::AbstractMatrix) - s = size(matrix) - @assert length(s) == 2 && s[1] === s[2] - for i in 1:s[1] - @inbounds matrix[i, i] += correction.v - end - return matrix -end - -""" - ClampEigenValuesCorrection - -One of the correction strategies for `correction!` function. Clamps eigen values of matrix to be equal or greater than fixed `v` term. - -# Arguments -- `v`: fixed value used to clamp eigen values - -See also: [`correction!`](@ref), [`NoCorrection`](@ref), [`FixedCorrection`](@ref), [`TinyCorrection`](@ref) -""" -struct ClampEigenValuesCorrection{T} <: AbstractCorrection - v::T -end - -correction!(correction::ClampEigenValuesCorrection, value::Real) = clamp(value, correction.v, Inf) - -function correction!(correction::ClampEigenValuesCorrection, matrix::AbstractMatrix) - s = size(matrix) - @assert length(s) == 2 && s[1] === s[2] - - F = svd(matrix) - clamp!(F.S, correction.v, Inf) - R = lmul!(Diagonal(F.S), F.Vt) - M = mul!(matrix, F.U, R) - - return M -end diff --git a/src/nodes/dot_product.jl b/src/nodes/dot_product.jl index 9bd06f31b..9cd9f6c2d 100644 --- a/src/nodes/dot_product.jl +++ b/src/nodes/dot_product.jl @@ -4,5 +4,5 @@ import LinearAlgebra: dot @node typeof(dot) Deterministic [out, in1, in2] -# By default dot-product node uses TinyCorrection() strategy for precision matrix on `in1` and `in2` edges to ensure precision is always invertible -default_meta(::typeof(dot)) = TinyCorrection() +# By default dot-product node uses `MatrixCorrectionTools.ReplaceZeroDiagonalEntries(tiny)` strategy for precision matrix on `in1` and `in2` edges to ensure precision is always invertible +default_meta(::typeof(dot)) = MatrixCorrectionTools.ReplaceZeroDiagonalEntries(tiny) diff --git a/src/nodes/multiplication.jl b/src/nodes/multiplication.jl index 39597a9f1..df37f1c14 100644 --- a/src/nodes/multiplication.jl +++ b/src/nodes/multiplication.jl @@ -1,5 +1,5 @@ @node typeof(*) Deterministic [out, A, in] -# By default multiplication node uses TinyCorrection() strategy for precision matrix on `in` edge to ensure precision is always invertible -default_meta(::typeof(*)) = TinyCorrection() +# By default multiplication node uses `MatrixCorrectionTools.ReplaceZeroDiagonalEntries(tiny)` strategy for precision matrix on `in` edge to ensure precision is always invertible +default_meta(::typeof(*)) = MatrixCorrectionTools.ReplaceZeroDiagonalEntries(tiny) diff --git a/src/rules/dot_product/in1.jl b/src/rules/dot_product/in1.jl index 685d60125..aeeb19d06 100644 --- a/src/rules/dot_product/in1.jl +++ b/src/rules/dot_product/in1.jl @@ -1,4 +1,4 @@ -@rule typeof(dot)(:in1, Marginalisation) (m_out::UnivariateNormalDistributionsFamily, m_in2::PointMass, meta::AbstractCorrection) = begin +@rule typeof(dot)(:in1, Marginalisation) (m_out::UnivariateNormalDistributionsFamily, m_in2::PointMass, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin return @call_rule typeof(dot)(:in2, Marginalisation) (m_out = m_out, m_in1 = m_in2, meta = meta) end diff --git a/src/rules/dot_product/in2.jl b/src/rules/dot_product/in2.jl index 6cd56e50f..fa36f8e5f 100644 --- a/src/rules/dot_product/in2.jl +++ b/src/rules/dot_product/in2.jl @@ -1,5 +1,5 @@ -@rule typeof(dot)(:in2, Marginalisation) (m_out::UnivariateNormalDistributionsFamily, m_in1::PointMass, meta::AbstractCorrection) = begin +@rule typeof(dot)(:in2, Marginalisation) (m_out::UnivariateNormalDistributionsFamily, m_in1::PointMass, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin A = mean(m_in1) out_wmean, out_prec = weightedmean_precision(m_out) diff --git a/src/rules/dot_product/marginals.jl b/src/rules/dot_product/marginals.jl index b9f6358f6..3333f5e35 100644 --- a/src/rules/dot_product/marginals.jl +++ b/src/rules/dot_product/marginals.jl @@ -1,5 +1,5 @@ -@marginalrule typeof(dot)(:in1_in2) (m_out::NormalDistributionsFamily, m_in1::PointMass, m_in2::NormalDistributionsFamily, meta::AbstractCorrection) = begin +@marginalrule typeof(dot)(:in1_in2) (m_out::NormalDistributionsFamily, m_in1::PointMass, m_in2::NormalDistributionsFamily, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin # Forward message towards `in2` edge mf_in2 = @call_rule typeof(dot)(:in2, Marginalisation) (m_out = m_out, m_in1 = m_in1, meta = meta) @@ -8,7 +8,7 @@ return convert_paramfloattype((in1 = m_in1, in2 = q_in2)) end -@marginalrule typeof(dot)(:in1_in2) (m_out::NormalDistributionsFamily, m_in1::NormalDistributionsFamily, m_in2::PointMass, meta::AbstractCorrection) = begin +@marginalrule typeof(dot)(:in1_in2) (m_out::NormalDistributionsFamily, m_in1::NormalDistributionsFamily, m_in2::PointMass, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin symmetric = @call_marginalrule typeof(dot)(:in1_in2) (m_out = m_out, m_in1 = m_in2, m_in2 = m_in1, meta = meta) return convert_paramfloattype((in1 = symmetric[:in2], in2 = symmetric[:in1])) end diff --git a/src/rules/dot_product/out.jl b/src/rules/dot_product/out.jl index aa2f78d8e..68babf4d9 100644 --- a/src/rules/dot_product/out.jl +++ b/src/rules/dot_product/out.jl @@ -1,9 +1,9 @@ -@rule typeof(dot)(:out, Marginalisation) (m_in1::NormalDistributionsFamily, m_in2::PointMass, meta::AbstractCorrection) = begin +@rule typeof(dot)(:out, Marginalisation) (m_in1::NormalDistributionsFamily, m_in2::PointMass, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin return @call_rule typeof(dot)(:out, Marginalisation) (m_in1 = m_in2, m_in2 = m_in1, meta = meta) end -@rule typeof(dot)(:out, Marginalisation) (m_in1::PointMass, m_in2::NormalDistributionsFamily, meta::AbstractCorrection) = begin +@rule typeof(dot)(:out, Marginalisation) (m_in1::PointMass, m_in2::NormalDistributionsFamily, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin A = mean(m_in1) in2_mean, in2_cov = mean_cov(m_in2) return NormalMeanVariance(dot(A, in2_mean), dot(A, in2_cov, A)) diff --git a/src/rules/multiplication/A.jl b/src/rules/multiplication/A.jl index 47577fb12..5e3fd0b60 100644 --- a/src/rules/multiplication/A.jl +++ b/src/rules/multiplication/A.jl @@ -1,12 +1,12 @@ -@rule typeof(*)(:A, Marginalisation) (m_out::PointMass, m_in::PointMass, meta::Union{<:AbstractCorrection, Nothing}) = PointMass(mean(m_in) \ mean(m_out)) +@rule typeof(*)(:A, Marginalisation) (m_out::PointMass, m_in::PointMass, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = PointMass(mean(m_in) \ mean(m_out)) -@rule typeof(*)(:A, Marginalisation) (m_out::GammaDistributionsFamily, m_in::PointMass{<:Real}, meta::Union{<:AbstractCorrection, Nothing}) = begin +@rule typeof(*)(:A, Marginalisation) (m_out::GammaDistributionsFamily, m_in::PointMass{<:Real}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin return GammaShapeRate(shape(m_out), rate(m_out) * mean(m_in)) end # if A is a matrix, then the result is multivariate -@rule typeof(*)(:A, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_in::PointMass{<:AbstractMatrix}, meta::Union{<:AbstractCorrection, Nothing}) = begin +@rule typeof(*)(:A, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_in::PointMass{<:AbstractMatrix}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin A = mean(m_in) ξ_out, W_out = weightedmean_precision(m_out) W = correction!(meta, A' * W_out * A) @@ -15,7 +15,7 @@ end # if A is a vector, then the result is univariate # this rule links to the special case (AbstractVector * Univariate) for forward (:out) rule -@rule typeof(*)(:A, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_in::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrection, Nothing}) = begin +@rule typeof(*)(:A, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_in::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin A = mean(m_in) ξ_out, W_out = weightedmean_precision(m_out) W = correction!(meta, dot(A, W_out, A)) @@ -23,7 +23,7 @@ end end # if A is a scalar, then the input is either univariate or multivariate -@rule typeof(*)(:A, Marginalisation) (m_out::F, m_in::PointMass{<:Real}, meta::Union{<:AbstractCorrection, Nothing}) where {F <: NormalDistributionsFamily} = begin +@rule typeof(*)(:A, Marginalisation) (m_out::F, m_in::PointMass{<:Real}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) where {F <: NormalDistributionsFamily} = begin A = mean(m_in) ξ_out, W_out = weightedmean_precision(m_out) W = correction!(meta, A^2 * W_out) @@ -31,7 +31,7 @@ end end # specialized versions for mean-covariance parameterization -@rule typeof(*)(:A, Marginalisation) (m_out::MvNormalMeanCovariance, m_in::PointMass{<:AbstractMatrix}, meta::Union{<:AbstractCorrection, Nothing}) = begin +@rule typeof(*)(:A, Marginalisation) (m_out::MvNormalMeanCovariance, m_in::PointMass{<:AbstractMatrix}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin A = mean(m_in) μ_out, Σ_out = mean_cov(m_out) @@ -42,7 +42,7 @@ end return MvNormalWeightedMeanPrecision(tmp * μ_out, W) end -@rule typeof(*)(:A, Marginalisation) (m_out::MvNormalMeanCovariance, m_in::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrection, Nothing}) = begin +@rule typeof(*)(:A, Marginalisation) (m_out::MvNormalMeanCovariance, m_in::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin A = mean(m_in) μ_out, Σ_out = mean_cov(m_out) @@ -53,14 +53,14 @@ end return NormalWeightedMeanPrecision(dot(tmp, μ_out), W) end -@rule typeof(*)(:A, Marginalisation) (m_out::UnivariateGaussianDistributionsFamily, m_in::UnivariateGaussianDistributionsFamily, meta::Union{<:AbstractCorrection, Nothing}) = begin +@rule typeof(*)(:A, Marginalisation) (m_out::UnivariateGaussianDistributionsFamily, m_in::UnivariateGaussianDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin μ_in, var_in = mean_var(m_in) μ_out, var_out = mean_var(m_out) log_backwardpass = (x) -> -log(abs(x)) - 0.5 * log(2π * (var_in + var_out / x^2)) - 1 / 2 * (μ_out - x * μ_in)^2 / (var_in * x^2 + var_out) return ContinuousUnivariateLogPdf(log_backwardpass) end -@rule typeof(*)(:A, Marginalisation) (m_out::UnivariateDistribution, m_in::UnivariateDistribution, meta::Union{<:AbstractCorrection, Nothing}) = begin +@rule typeof(*)(:A, Marginalisation) (m_out::UnivariateDistribution, m_in::UnivariateDistribution, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin nsamples = 3000 samples_in = rand(m_in, nsamples) p = make_inversedist_message(samples_in, m_out) diff --git a/src/rules/multiplication/in.jl b/src/rules/multiplication/in.jl index 4dca95d14..fb1727737 100644 --- a/src/rules/multiplication/in.jl +++ b/src/rules/multiplication/in.jl @@ -1,12 +1,12 @@ -@rule typeof(*)(:in, Marginalisation) (m_out::PointMass, m_A::PointMass, meta::Union{<:AbstractCorrection, Nothing}) = PointMass(mean(m_A) \ mean(m_out)) +@rule typeof(*)(:in, Marginalisation) (m_out::PointMass, m_A::PointMass, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = PointMass(mean(m_A) \ mean(m_out)) -@rule typeof(*)(:in, Marginalisation) (m_out::GammaDistributionsFamily, m_A::PointMass{<:Real}, meta::Union{<:AbstractCorrection, Nothing}) = begin +@rule typeof(*)(:in, Marginalisation) (m_out::GammaDistributionsFamily, m_A::PointMass{<:Real}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin return GammaShapeRate(shape(m_out), rate(m_out) * mean(m_A)) end # if A is a matrix, then the result is multivariate -@rule typeof(*)(:in, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_A::PointMass{<:AbstractMatrix}, meta::Union{<:AbstractCorrection, Nothing}) = begin +@rule typeof(*)(:in, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_A::PointMass{<:AbstractMatrix}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin A = mean(m_A) ξ_out, W_out = weightedmean_precision(m_out) W = correction!(meta, A' * W_out * A) @@ -15,7 +15,7 @@ end # if A is a vector, then the result is univariate # this rule links to the special case (AbstractVector * Univariate) for forward (:out) rule -@rule typeof(*)(:in, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_A::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrection, Nothing}) = begin +@rule typeof(*)(:in, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_A::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin A = mean(m_A) ξ_out, W_out = weightedmean_precision(m_out) W = correction!(meta, dot(A, W_out, A)) @@ -23,7 +23,7 @@ end end # if A is a scalar, then the input is either univariate or multivariate -@rule typeof(*)(:in, Marginalisation) (m_out::F, m_A::PointMass{<:Real}, meta::Union{<:AbstractCorrection, Nothing}) where {F <: NormalDistributionsFamily} = begin +@rule typeof(*)(:in, Marginalisation) (m_out::F, m_A::PointMass{<:Real}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) where {F <: NormalDistributionsFamily} = begin A = mean(m_A) @logscale -logdet(A) ξ_out, W_out = weightedmean_precision(m_out) @@ -32,7 +32,7 @@ end end # specialized versions for mean-covariance parameterization -@rule typeof(*)(:in, Marginalisation) (m_out::MvNormalMeanCovariance, m_A::PointMass{<:AbstractMatrix}, meta::Union{<:AbstractCorrection, Nothing}) = begin +@rule typeof(*)(:in, Marginalisation) (m_out::MvNormalMeanCovariance, m_A::PointMass{<:AbstractMatrix}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin A = mean(m_A) μ_out, Σ_out = mean_cov(m_out) @@ -43,7 +43,7 @@ end return MvNormalWeightedMeanPrecision(tmp * μ_out, W) end -@rule typeof(*)(:in, Marginalisation) (m_out::MvNormalMeanCovariance, m_A::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrection, Nothing}) = begin +@rule typeof(*)(:in, Marginalisation) (m_out::MvNormalMeanCovariance, m_A::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin A = mean(m_A) μ_out, Σ_out = mean_cov(m_out) @@ -54,13 +54,13 @@ end return NormalWeightedMeanPrecision(dot(tmp, μ_out), W) end -@rule typeof(*)(:in, Marginalisation) (m_out::UnivariateGaussianDistributionsFamily, m_A::UnivariateGaussianDistributionsFamily, meta::Union{<:AbstractCorrection, Nothing}) = begin +@rule typeof(*)(:in, Marginalisation) (m_out::UnivariateGaussianDistributionsFamily, m_A::UnivariateGaussianDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin μ_A, var_A = mean_var(m_A) μ_out, var_out = mean_var(m_out) log_backwardpass = (x) -> -log(abs(x)) - 0.5 * log(2π * (var_A + var_out / x^2)) - 1 / 2 * (μ_out - x * μ_A)^2 / (var_A * x^2 + var_out) return ContinuousUnivariateLogPdf(log_backwardpass) end -@rule typeof(*)(:in, Marginalisation) (m_out::UnivariateDistribution, m_A::UnivariateDistribution, meta::Union{<:AbstractCorrection, Nothing}) = begin +@rule typeof(*)(:in, Marginalisation) (m_out::UnivariateDistribution, m_A::UnivariateDistribution, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin return @call_rule typeof(*)(:A, Marginalisation) (m_out = m_out, m_in = m_A, meta = meta) end diff --git a/src/rules/multiplication/marginals.jl b/src/rules/multiplication/marginals.jl index dee6d936c..fc27f8911 100644 --- a/src/rules/multiplication/marginals.jl +++ b/src/rules/multiplication/marginals.jl @@ -1,5 +1,5 @@ -@marginalrule typeof(*)(:A_in) (m_out::NormalDistributionsFamily, m_A::PointMass, m_in::NormalDistributionsFamily, meta::Union{<:AbstractCorrection, Nothing}) = begin +@marginalrule typeof(*)(:A_in) (m_out::NormalDistributionsFamily, m_A::PointMass, m_in::NormalDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin b_in = @call_rule typeof(*)(:in, Marginalisation) (m_out = m_out, m_A = m_A, meta = meta) q_in = prod(ProdAnalytical(), b_in, m_in) return (A = m_A, in = q_in) @@ -9,7 +9,7 @@ end # Note that for multivariate case in general multiplication is not a commutative operation, # but for scalars we make an exception @marginalrule typeof(*)(:A_in) ( - m_out::UnivariateNormalDistributionsFamily, m_A::UnivariateNormalDistributionsFamily, m_in::PointMass{<:Real}, meta::Union{<:AbstractCorrection, Nothing} + m_out::UnivariateNormalDistributionsFamily, m_A::UnivariateNormalDistributionsFamily, m_in::PointMass{<:Real}, meta::Union{<:AbstractCorrectionStrategy, Nothing} ) = begin return @call_marginalrule typeof(*)(:A_in) (m_out = m_out, m_A = m_in, m_in = m_A, meta = meta) end diff --git a/src/rules/multiplication/out.jl b/src/rules/multiplication/out.jl index e022f80e8..c4494287e 100644 --- a/src/rules/multiplication/out.jl +++ b/src/rules/multiplication/out.jl @@ -1,23 +1,23 @@ import SpecialFunctions: besselk -@rule typeof(*)(:out, Marginalisation) (m_A::PointMass, m_in::PointMass, meta::Union{<:AbstractCorrection, Nothing}) = PointMass(mean(m_A) * mean(m_in)) +@rule typeof(*)(:out, Marginalisation) (m_A::PointMass, m_in::PointMass, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = PointMass(mean(m_A) * mean(m_in)) -@rule typeof(*)(:out, Marginalisation) (m_A::PointMass{<:Real}, m_in::GammaDistributionsFamily, meta::Union{<:AbstractCorrection, Nothing}) = begin +@rule typeof(*)(:out, Marginalisation) (m_A::PointMass{<:Real}, m_in::GammaDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin return GammaShapeRate(shape(m_in), rate(m_in) / mean(m_A)) end -@rule typeof(*)(:out, Marginalisation) (m_A::GammaDistributionsFamily, m_in::PointMass{<:Real}, meta::Union{<:AbstractCorrection, Nothing}) = begin +@rule typeof(*)(:out, Marginalisation) (m_A::GammaDistributionsFamily, m_in::PointMass{<:Real}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin return @call_rule typeof(*)(:out, Marginalisation) (m_A = m_in, m_in = m_A, meta = meta, addons = getaddons()) # symmetric rule end -@rule typeof(*)(:out, Marginalisation) (m_A::PointMass{<:AbstractMatrix}, m_in::F, meta::Union{<:AbstractCorrection, Nothing}) where {F <: NormalDistributionsFamily} = begin +@rule typeof(*)(:out, Marginalisation) (m_A::PointMass{<:AbstractMatrix}, m_in::F, meta::Union{<:AbstractCorrectionStrategy, Nothing}) where {F <: NormalDistributionsFamily} = begin @logscale 0 A = mean(m_A) μ_in, Σ_in = mean_cov(m_in) return convert(promote_variate_type(F, NormalMeanVariance), A * μ_in, A * Σ_in * A') end -@rule typeof(*)(:out, Marginalisation) (m_A::F, m_in::PointMass{<:AbstractMatrix}, meta::Union{<:AbstractCorrection, Nothing}) where {F <: NormalDistributionsFamily} = begin +@rule typeof(*)(:out, Marginalisation) (m_A::F, m_in::PointMass{<:AbstractMatrix}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) where {F <: NormalDistributionsFamily} = begin return @call_rule typeof(*)(:out, Marginalisation) (m_A = m_in, m_in = m_A, meta = meta, addons = getaddons()) # symmetric rule end @@ -33,7 +33,7 @@ end # v out ~ Multivariate -> R^n # -->[x]--> # in1 ~ Univariate -> R^1 -@rule typeof(*)(:out, Marginalisation) (m_A::PointMass{<:AbstractVector}, m_in::UnivariateNormalDistributionsFamily, meta::Union{<:AbstractCorrection, Nothing}) = begin +@rule typeof(*)(:out, Marginalisation) (m_A::PointMass{<:AbstractVector}, m_in::UnivariateNormalDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin @logscale 0 a = mean(m_A) @@ -47,35 +47,35 @@ end return MvNormalMeanCovariance(μ, Σ) end -@rule typeof(*)(:out, Marginalisation) (m_A::UnivariateNormalDistributionsFamily, m_in::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrection, Nothing}) = begin +@rule typeof(*)(:out, Marginalisation) (m_A::UnivariateNormalDistributionsFamily, m_in::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin return @call_rule typeof(*)(:out, Marginalisation) (m_A = m_in, m_in = m_A, meta = meta, addons = getaddons()) # symmetric rule end #------------------------ # Real * UnivariateNormalDistributions #------------------------ -@rule typeof(*)(:out, Marginalisation) (m_A::PointMass{<:Real}, m_in::UnivariateNormalDistributionsFamily, meta::Union{<:AbstractCorrection, Nothing}) = begin +@rule typeof(*)(:out, Marginalisation) (m_A::PointMass{<:Real}, m_in::UnivariateNormalDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin @logscale 0 a = mean(m_A) μ_in, v_in = mean_var(m_in) return NormalMeanVariance(a * μ_in, a^2 * v_in) end -@rule typeof(*)(:out, Marginalisation) (m_A::UnivariateNormalDistributionsFamily, m_in::PointMass{<:Real}, meta::Union{<:AbstractCorrection, Nothing}) = begin +@rule typeof(*)(:out, Marginalisation) (m_A::UnivariateNormalDistributionsFamily, m_in::PointMass{<:Real}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin return @call_rule typeof(*)(:out, Marginalisation) (m_A = m_in, m_in = m_A, meta = meta, addons = getaddons()) # symmetric rule end #----------------------- # Univariate Normal * Univariate Normal #---------------------- -@rule typeof(*)(:out, Marginalisation) (m_A::UnivariateGaussianDistributionsFamily, m_in::UnivariateGaussianDistributionsFamily, meta::Union{<:AbstractCorrection, Nothing}) = begin +@rule typeof(*)(:out, Marginalisation) (m_A::UnivariateGaussianDistributionsFamily, m_in::UnivariateGaussianDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin μ_A, var_A = mean_var(m_A) μ_in, var_in = mean_var(m_in) return ContinuousUnivariateLogPdf(besselmod(μ_in, var_in, μ_A, var_A, 0.0)) end -@rule typeof(*)(:out, Marginalisation) (m_A::UnivariateDistribution, m_in::UnivariateDistribution, meta::Union{<:AbstractCorrection, Nothing}) = begin +@rule typeof(*)(:out, Marginalisation) (m_A::UnivariateDistribution, m_in::UnivariateDistribution, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin nsamples = 3000 samples_A = rand(m_A, nsamples) p = make_productdist_message(samples_A, m_in) diff --git a/test/algebra/test_correction.jl b/test/algebra/test_correction.jl deleted file mode 100644 index d5f9774c7..000000000 --- a/test/algebra/test_correction.jl +++ /dev/null @@ -1,62 +0,0 @@ -module ReactiveMPCorrectionTest - -using Test -using ReactiveMP -using Random - -using LinearAlgebra - -@testset "Correction" begin - rng = MersenneTwister(1234) - - for n in [3, 5, 10] - a = rand(rng) - A = rand(rng, n, n) - - b = ReactiveMP.correction!(NoCorrection(), a) - B = ReactiveMP.correction!(NoCorrection(), A) - - @test a == b - @test a === b - @test A == B - @test A === B - - c = ReactiveMP.correction!(TinyCorrection(), ReactiveMP.tiny / 2) - @test c >= ReactiveMP.tiny - - C = ReactiveMP.correction!(TinyCorrection(), copy(A)) - - @test A ≈ C - @test mapreduce((d) -> d[1] + ReactiveMP.tiny === d[2], &, zip(diag(A), diag(C))) - - D = rand(rng, n, n) - E = ReactiveMP.correction!(TinyCorrection(), D) - @test D === E - - v = 1e-10 * rand(rng) - - f = ReactiveMP.correction!(FixedCorrection(v), v / 2) - @test f >= v - - F = ReactiveMP.correction!(FixedCorrection(v), copy(A)) - @test A ≈ F - @test mapreduce((d) -> d[1] + v === d[2], &, zip(diag(A), diag(F))) - - G = rand(rng, n, n) - H = ReactiveMP.correction!(FixedCorrection(v), G) - @test G === H - - j = ReactiveMP.correction!(ClampEigenValuesCorrection(10.0), 5.0) - @test j >= 10.0 - - J = ReactiveMP.correction!(ClampEigenValuesCorrection(10.0), copy(A)) - S_J = svd(J) - - @test mapreduce((d) -> d >= 10.0 || d ≈ 10.0, &, S_J.S) - - K = ReactiveMP.correction!(ClampEigenValuesCorrection(1e-12), copy(A)) - @test K ≈ A - end -end - -end diff --git a/test/rules/dot_product/test_in1.jl b/test/rules/dot_product/test_in1.jl index 34b4b7d9e..2260a06dc 100644 --- a/test/rules/dot_product/test_in1.jl +++ b/test/rules/dot_product/test_in1.jl @@ -6,8 +6,10 @@ using Random import ReactiveMP: @test_rules import LinearAlgebra: dot +import MatrixCorrectionTools: NoCorrection, AddToDiagonalEntries, ReplaceZeroDiagonalEntries @testset "rules:typeof(dot):in1" begin + @testset "Belief Propagation: (m_out::UnivariateNormalDistributionsFamily, m_in2::PointMass)" begin @test_rules [check_type_promotion = true] typeof(dot)(:in1, Marginalisation) [ (input = (m_out = NormalMeanVariance(2.0, 2.0), m_in2 = PointMass(-1.0), meta = NoCorrection()), output = NormalWeightedMeanPrecision(-1.0, 0.5)), @@ -16,9 +18,9 @@ import LinearAlgebra: dot ] @test_rules [check_type_promotion = true] typeof(dot)(:in1, Marginalisation) [ - (input = (m_out = NormalMeanVariance(2.0, 2.0), m_in2 = PointMass(-1.0), meta = TinyCorrection()), output = NormalWeightedMeanPrecision(-1.0, 0.5)), - (input = (m_out = NormalMeanPrecision(1.0, 1.0), m_in2 = PointMass(-2.0), meta = TinyCorrection()), output = NormalWeightedMeanPrecision(-2.0, 4.0)), - (input = (m_out = NormalWeightedMeanPrecision(2.0, 1.0), m_in2 = PointMass(-1.0), meta = TinyCorrection()), output = NormalWeightedMeanPrecision(-2.0, 1.0)) + (input = (m_out = NormalMeanVariance(2.0, 2.0), m_in2 = PointMass(-1.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(-1.0, 0.5)), + (input = (m_out = NormalMeanPrecision(1.0, 1.0), m_in2 = PointMass(-2.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(-2.0, 4.0)), + (input = (m_out = NormalWeightedMeanPrecision(2.0, 1.0), m_in2 = PointMass(-1.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(-2.0, 1.0)) ] @test_rules [check_type_promotion = true] typeof(dot)(:in1, Marginalisation) [ @@ -28,9 +30,9 @@ import LinearAlgebra: dot ] @test_rules [check_type_promotion = true] typeof(dot)(:in1, Marginalisation) [ - (input = (m_out = NormalMeanVariance(2.0, 2.0), m_in2 = PointMass(0.0), meta = TinyCorrection()), output = NormalWeightedMeanPrecision(0.0, tiny)), - (input = (m_out = NormalMeanPrecision(1.0, 1.0), m_in2 = PointMass(0.0), meta = TinyCorrection()), output = NormalWeightedMeanPrecision(0.0, tiny)), - (input = (m_out = NormalWeightedMeanPrecision(2.0, 1.0), m_in2 = PointMass(0.0), meta = TinyCorrection()), output = NormalWeightedMeanPrecision(0.0, tiny)) + (input = (m_out = NormalMeanVariance(2.0, 2.0), m_in2 = PointMass(0.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(0.0, tiny)), + (input = (m_out = NormalMeanPrecision(1.0, 1.0), m_in2 = PointMass(0.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(0.0, tiny)), + (input = (m_out = NormalWeightedMeanPrecision(2.0, 1.0), m_in2 = PointMass(0.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(0.0, tiny)) ] end @@ -50,17 +52,32 @@ import LinearAlgebra: dot ) ] - @test_rules [check_type_promotion = false] typeof(dot)(:in1, Marginalisation) [ + @test_rules [check_type_promotion = true] typeof(dot)(:in1, Marginalisation) [ + ( + input = (m_out = NormalMeanVariance(2.0, 1.0), m_in2 = PointMass([-1.0, 2.0]), meta = ReplaceZeroDiagonalEntries(tiny)), + output = MvNormalWeightedMeanPrecision([-2.0, 4.0], [1.0 -2.0; -2.0 4.0]) + ), + ( + input = (m_out = NormalMeanPrecision(1.0, inv(2.0)), m_in2 = PointMass([1.0, 1.0]), meta = ReplaceZeroDiagonalEntries(tiny)), + output = MvNormalWeightedMeanPrecision([0.5, 0.5], [0.5 0.5; 0.5 0.5]) + ), + ( + input = (m_out = NormalWeightedMeanPrecision(2.0, 1.0), m_in2 = PointMass([-2.0, 3.0]), meta = ReplaceZeroDiagonalEntries(tiny)), + output = MvNormalWeightedMeanPrecision([-4.0, 6.0], [4.0 -6.0; -6.0 9.0]) + ) + ] + + @test_rules [check_type_promotion = true] typeof(dot)(:in1, Marginalisation) [ ( - input = (m_out = NormalMeanVariance(2.0, 1.0), m_in2 = PointMass([-1.0, 2.0]), meta = TinyCorrection()), + input = (m_out = NormalMeanVariance(2.0, 1.0), m_in2 = PointMass([-1.0, 2.0]), meta = AddToDiagonalEntries(tiny)), output = MvNormalWeightedMeanPrecision([-2.0, 4.0], [1.0+tiny -2.0; -2.0 4.0+tiny]) ), ( - input = (m_out = NormalMeanPrecision(1.0, inv(2.0)), m_in2 = PointMass([1.0, 1.0]), meta = TinyCorrection()), + input = (m_out = NormalMeanPrecision(1.0, inv(2.0)), m_in2 = PointMass([1.0, 1.0]), meta = AddToDiagonalEntries(tiny)), output = MvNormalWeightedMeanPrecision([0.5, 0.5], [0.5+tiny 0.5; 0.5 0.5+tiny]) ), ( - input = (m_out = NormalWeightedMeanPrecision(2.0, 1.0), m_in2 = PointMass([-2.0, 3.0]), meta = TinyCorrection()), + input = (m_out = NormalWeightedMeanPrecision(2.0, 1.0), m_in2 = PointMass([-2.0, 3.0]), meta = AddToDiagonalEntries(tiny)), output = MvNormalWeightedMeanPrecision([-4.0, 6.0], [4.0+tiny -6.0; -6.0 9.0+tiny]) ) ] diff --git a/test/rules/dot_product/test_in2.jl b/test/rules/dot_product/test_in2.jl index 9499247f5..a2b12d35e 100644 --- a/test/rules/dot_product/test_in2.jl +++ b/test/rules/dot_product/test_in2.jl @@ -5,6 +5,7 @@ using ReactiveMP using Random import ReactiveMP: @test_rules import LinearAlgebra: dot +import MatrixCorrectionTools: NoCorrection, AddToDiagonalEntries, ReplaceZeroDiagonalEntries @testset "rules:typeof(dot):in2" begin @testset "Belief Propagation: (m_out::UnivariateNormalDistributionsFamily, m_in1::PointMass)" begin @@ -15,9 +16,9 @@ import LinearAlgebra: dot ] @test_rules [check_type_promotion = true] typeof(dot)(:in2, Marginalisation) [ - (input = (m_out = NormalMeanVariance(2.0, 2.0), m_in1 = PointMass(-1.0), meta = TinyCorrection()), output = NormalWeightedMeanPrecision(-1.0, 0.5)), - (input = (m_out = NormalMeanPrecision(1.0, 1.0), m_in1 = PointMass(-2.0), meta = TinyCorrection()), output = NormalWeightedMeanPrecision(-2.0, 4.0)), - (input = (m_out = NormalWeightedMeanPrecision(2.0, 1.0), m_in1 = PointMass(-1.0), meta = TinyCorrection()), output = NormalWeightedMeanPrecision(-2.0, 1.0)) + (input = (m_out = NormalMeanVariance(2.0, 2.0), m_in1 = PointMass(-1.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(-1.0, 0.5)), + (input = (m_out = NormalMeanPrecision(1.0, 1.0), m_in1 = PointMass(-2.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(-2.0, 4.0)), + (input = (m_out = NormalWeightedMeanPrecision(2.0, 1.0), m_in1 = PointMass(-1.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(-2.0, 1.0)) ] @test_rules [check_type_promotion = true] typeof(dot)(:in2, Marginalisation) [ @@ -27,9 +28,9 @@ import LinearAlgebra: dot ] @test_rules [check_type_promotion = true] typeof(dot)(:in2, Marginalisation) [ - (input = (m_out = NormalMeanVariance(2.0, 2.0), m_in1 = PointMass(0.0), meta = TinyCorrection()), output = NormalWeightedMeanPrecision(0.0, tiny)), - (input = (m_out = NormalMeanPrecision(1.0, 1.0), m_in1 = PointMass(0.0), meta = TinyCorrection()), output = NormalWeightedMeanPrecision(0.0, tiny)), - (input = (m_out = NormalWeightedMeanPrecision(2.0, 1.0), m_in1 = PointMass(0.0), meta = TinyCorrection()), output = NormalWeightedMeanPrecision(0.0, tiny)) + (input = (m_out = NormalMeanVariance(2.0, 2.0), m_in1 = PointMass(0.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(0.0, tiny)), + (input = (m_out = NormalMeanPrecision(1.0, 1.0), m_in1 = PointMass(0.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(0.0, tiny)), + (input = (m_out = NormalWeightedMeanPrecision(2.0, 1.0), m_in1 = PointMass(0.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(0.0, tiny)) ] end @@ -49,17 +50,32 @@ import LinearAlgebra: dot ) ] - @test_rules [check_type_promotion = false] typeof(dot)(:in2, Marginalisation) [ + @test_rules [check_type_promotion = true] typeof(dot)(:in2, Marginalisation) [ + ( + input = (m_out = NormalMeanVariance(2.0, 2.0), m_in1 = PointMass([-1.0, 1.0]), meta = ReplaceZeroDiagonalEntries(tiny)), + output = MvNormalWeightedMeanPrecision([-1.0, 1.0], [0.5 -0.5; -0.5 0.5]) + ), + ( + input = (m_out = NormalMeanPrecision(1.0, inv(2.0)), m_in1 = PointMass([2.0, 1.0]), meta = ReplaceZeroDiagonalEntries(tiny)), + output = MvNormalWeightedMeanPrecision([1.0, 0.5], [2.0 1.0; 1.0 0.5]) + ), + ( + input = (m_out = NormalWeightedMeanPrecision(1.0, 1.0), m_in1 = PointMass([-1.0, 3.0]), meta = ReplaceZeroDiagonalEntries(tiny)), + output = MvNormalWeightedMeanPrecision([-1.0, 3.0], [1.0 -3.0; -3.0 9.0]) + ) + ] + + @test_rules [check_type_promotion = true] typeof(dot)(:in2, Marginalisation) [ ( - input = (m_out = NormalMeanVariance(2.0, 2.0), m_in1 = PointMass([-1.0, 1.0]), meta = TinyCorrection()), + input = (m_out = NormalMeanVariance(2.0, 2.0), m_in1 = PointMass([-1.0, 1.0]), meta = AddToDiagonalEntries(tiny)), output = MvNormalWeightedMeanPrecision([-1.0, 1.0], [0.5+tiny -0.5; -0.5 0.5+tiny]) ), ( - input = (m_out = NormalMeanPrecision(1.0, inv(2.0)), m_in1 = PointMass([2.0, 1.0]), meta = TinyCorrection()), + input = (m_out = NormalMeanPrecision(1.0, inv(2.0)), m_in1 = PointMass([2.0, 1.0]), meta = AddToDiagonalEntries(tiny)), output = MvNormalWeightedMeanPrecision([1.0, 0.5], [2.0+tiny 1.0; 1.0 0.5+tiny]) ), ( - input = (m_out = NormalWeightedMeanPrecision(1.0, 1.0), m_in1 = PointMass([-1.0, 3.0]), meta = TinyCorrection()), + input = (m_out = NormalWeightedMeanPrecision(1.0, 1.0), m_in1 = PointMass([-1.0, 3.0]), meta = AddToDiagonalEntries(tiny)), output = MvNormalWeightedMeanPrecision([-1.0, 3.0], [1.0+tiny -3.0; -3.0 9.0+tiny]) ) ] diff --git a/test/rules/dot_product/test_marginals.jl b/test/rules/dot_product/test_marginals.jl index dcdad58f6..b8101e780 100644 --- a/test/rules/dot_product/test_marginals.jl +++ b/test/rules/dot_product/test_marginals.jl @@ -6,6 +6,7 @@ using Random import ReactiveMP: @test_marginalrules import LinearAlgebra: dot +import MatrixCorrectionTools: NoCorrection, ReplaceZeroDiagonalEntries @testset "marginalrules:DotProduct" begin @testset "in1_in2: (m_out::UnivariateNormalDistributionsFamily, m_in1::PointMass, m_in2::UnivariateNormalDistributionsFamily)" begin @@ -24,17 +25,17 @@ import LinearAlgebra: dot ) ] - @test_marginalrules [check_type_promotion = false] typeof(dot)(:in1_in2) [ + @test_marginalrules [check_type_promotion = true] typeof(dot)(:in1_in2) [ ( - input = (m_out = NormalMeanVariance(1.0, 2.0), m_in1 = PointMass(1.0), m_in2 = NormalMeanVariance(2.0, 2.0), meta = TinyCorrection()), + input = (m_out = NormalMeanVariance(1.0, 2.0), m_in1 = PointMass(1.0), m_in2 = NormalMeanVariance(2.0, 2.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = (in1 = PointMass(1.0), in2 = NormalWeightedMeanPrecision(1.5, 1.0)) ), ( - input = (m_out = NormalMeanPrecision(1.0, 1.0), m_in1 = PointMass(2.0), m_in2 = NormalMeanVariance(1.0, 2.0), meta = TinyCorrection()), + input = (m_out = NormalMeanPrecision(1.0, 1.0), m_in1 = PointMass(2.0), m_in2 = NormalMeanVariance(1.0, 2.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = (in1 = PointMass(2.0), in2 = NormalWeightedMeanPrecision(2.5, 4.5)) ), ( - input = (m_out = NormalWeightedMeanPrecision(0.5, 0.5), m_in1 = PointMass(-2.0), m_in2 = NormalMeanVariance(1.0, 2.0), meta = TinyCorrection()), + input = (m_out = NormalWeightedMeanPrecision(0.5, 0.5), m_in1 = PointMass(-2.0), m_in2 = NormalMeanVariance(1.0, 2.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = (in1 = PointMass(-2.0), in2 = NormalWeightedMeanPrecision(-0.5, 2.5)) ) ] @@ -54,17 +55,17 @@ import LinearAlgebra: dot ) ] - @test_marginalrules [check_type_promotion = false] typeof(dot)(:in1_in2) [ + @test_marginalrules [check_type_promotion = true] typeof(dot)(:in1_in2) [ ( - input = (m_out = NormalMeanVariance(1.0, 2.0), m_in1 = PointMass(1.0), m_in2 = NormalMeanVariance(2.0, 2.0), meta = TinyCorrection()), + input = (m_out = NormalMeanVariance(1.0, 2.0), m_in1 = PointMass(1.0), m_in2 = NormalMeanVariance(2.0, 2.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = (in1 = PointMass(1.0), in2 = NormalWeightedMeanPrecision(1.5, 1.0)) ), ( - input = (m_out = NormalMeanVariance(1.0, 1.0), m_in1 = PointMass(2.0), m_in2 = NormalMeanPrecision(1.0, 0.5), meta = TinyCorrection()), + input = (m_out = NormalMeanVariance(1.0, 1.0), m_in1 = PointMass(2.0), m_in2 = NormalMeanPrecision(1.0, 0.5), meta = ReplaceZeroDiagonalEntries(tiny)), output = (in1 = PointMass(2.0), in2 = NormalWeightedMeanPrecision(2.5, 4.5)) ), ( - input = (m_out = NormalMeanVariance(1.0, 2.0), m_in1 = PointMass(-2.0), m_in2 = NormalWeightedMeanPrecision(0.5, 0.5), meta = TinyCorrection()), + input = (m_out = NormalMeanVariance(1.0, 2.0), m_in1 = PointMass(-2.0), m_in2 = NormalWeightedMeanPrecision(0.5, 0.5), meta = ReplaceZeroDiagonalEntries(tiny)), output = (in1 = PointMass(-2.0), in2 = NormalWeightedMeanPrecision(-0.5, 2.5)) ) ] @@ -86,17 +87,17 @@ import LinearAlgebra: dot ) ] - @test_marginalrules [check_type_promotion = false] typeof(dot)(:in1_in2) [ + @test_marginalrules [check_type_promotion = true] typeof(dot)(:in1_in2) [ ( - input = (m_out = NormalMeanVariance(1.0, 1.0), m_in1 = NormalMeanVariance(1.0, 2.0), m_in2 = PointMass(2.0), meta = TinyCorrection()), + input = (m_out = NormalMeanVariance(1.0, 1.0), m_in1 = NormalMeanVariance(1.0, 2.0), m_in2 = PointMass(2.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = (in1 = NormalWeightedMeanPrecision(2.5, 4.5), in2 = PointMass(2.0)) ), ( - input = (m_out = NormalMeanPrecision(3.0, 1.0), m_in1 = NormalMeanVariance(1.0, 2.0), m_in2 = PointMass(2.0), meta = TinyCorrection()), + input = (m_out = NormalMeanPrecision(3.0, 1.0), m_in1 = NormalMeanVariance(1.0, 2.0), m_in2 = PointMass(2.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = (in1 = NormalWeightedMeanPrecision(6.5, 4.5), in2 = PointMass(2.0)) ), ( - input = (m_out = NormalWeightedMeanPrecision(4.0, 1.0), m_in1 = NormalMeanVariance(1.0, 3.0), m_in2 = PointMass(2.0), meta = TinyCorrection()), + input = (m_out = NormalWeightedMeanPrecision(4.0, 1.0), m_in1 = NormalMeanVariance(1.0, 3.0), m_in2 = PointMass(2.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = (in1 = NormalWeightedMeanPrecision(25 / 3, 13 / 3), in2 = PointMass(2.0)) ) ] @@ -118,17 +119,17 @@ import LinearAlgebra: dot ) ] - @test_marginalrules [check_type_promotion = false] typeof(dot)(:in1_in2) [ + @test_marginalrules [check_type_promotion = true] typeof(dot)(:in1_in2) [ ( - input = (m_out = NormalMeanVariance(1.0, 1.0), m_in1 = NormalMeanVariance(1.0, 2.0), m_in2 = PointMass(2.0), meta = TinyCorrection()), + input = (m_out = NormalMeanVariance(1.0, 1.0), m_in1 = NormalMeanVariance(1.0, 2.0), m_in2 = PointMass(2.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = (in1 = NormalWeightedMeanPrecision(2.5, 4.5), in2 = PointMass(2.0)) ), ( - input = (m_out = NormalMeanVariance(3.0, 1.0), m_in1 = NormalMeanPrecision(1.0, 0.5), m_in2 = PointMass(2.0), meta = TinyCorrection()), + input = (m_out = NormalMeanVariance(3.0, 1.0), m_in1 = NormalMeanPrecision(1.0, 0.5), m_in2 = PointMass(2.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = (in1 = NormalWeightedMeanPrecision(6.5, 4.5), in2 = PointMass(2.0)) ), ( - input = (m_out = NormalMeanVariance(4.0, 1.0), m_in1 = NormalWeightedMeanPrecision(1.0 / 3.0, 1.0 / 3.0), m_in2 = PointMass(2.0), meta = TinyCorrection()), + input = (m_out = NormalMeanVariance(4.0, 1.0), m_in1 = NormalWeightedMeanPrecision(1.0 / 3.0, 1.0 / 3.0), m_in2 = PointMass(2.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = (in1 = NormalWeightedMeanPrecision(25 / 3, 13 / 3), in2 = PointMass(2.0)) ) ] @@ -152,15 +153,15 @@ import LinearAlgebra: dot @test_marginalrules [check_type_promotion = false] typeof(dot)(:in1_in2) [ ( - input = (m_out = NormalMeanVariance(1.0, 1.0), m_in1 = PointMass(2.0), m_in2 = NormalMeanVariance(1.0, 2.0), meta = TinyCorrection()), + input = (m_out = NormalMeanVariance(1.0, 1.0), m_in1 = PointMass(2.0), m_in2 = NormalMeanVariance(1.0, 2.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = (in1 = PointMass(2.0), in2 = NormalWeightedMeanPrecision(2.5, 4.5)) ), ( - input = (m_out = NormalMeanPrecision(3.0, 1.0), m_in1 = PointMass(2.0), m_in2 = NormalMeanVariance(1.0, 2.0), meta = TinyCorrection()), + input = (m_out = NormalMeanPrecision(3.0, 1.0), m_in1 = PointMass(2.0), m_in2 = NormalMeanVariance(1.0, 2.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = (in1 = PointMass(2.0), in2 = NormalWeightedMeanPrecision(6.5, 4.5)) ), ( - input = (m_out = NormalWeightedMeanPrecision(4.0, 1.0), m_in1 = PointMass(2.0), m_in2 = NormalMeanVariance(1.0, 3.0), meta = TinyCorrection()), + input = (m_out = NormalWeightedMeanPrecision(4.0, 1.0), m_in1 = PointMass(2.0), m_in2 = NormalMeanVariance(1.0, 3.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = (in1 = PointMass(2.0), in2 = NormalWeightedMeanPrecision(25 / 3, 13 / 3)) ) ] @@ -182,17 +183,17 @@ import LinearAlgebra: dot ) ] - @test_marginalrules [check_type_promotion = false] typeof(dot)(:in1_in2) [ + @test_marginalrules [check_type_promotion = true] typeof(dot)(:in1_in2) [ ( - input = (m_out = NormalMeanVariance(1.0, 1.0), m_in1 = PointMass(2.0), m_in2 = NormalMeanVariance(1.0, 2.0), meta = TinyCorrection()), + input = (m_out = NormalMeanVariance(1.0, 1.0), m_in1 = PointMass(2.0), m_in2 = NormalMeanVariance(1.0, 2.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = (in1 = PointMass(2.0), in2 = NormalWeightedMeanPrecision(2.5, 4.5)) ), ( - input = (m_out = NormalMeanVariance(3.0, 1.0), m_in1 = PointMass(2.0), m_in2 = NormalMeanPrecision(1.0, 0.5), meta = TinyCorrection()), + input = (m_out = NormalMeanVariance(3.0, 1.0), m_in1 = PointMass(2.0), m_in2 = NormalMeanPrecision(1.0, 0.5), meta = ReplaceZeroDiagonalEntries(tiny)), output = (in1 = PointMass(2.0), in2 = NormalWeightedMeanPrecision(6.5, 4.5)) ), ( - input = (m_out = NormalMeanVariance(4.0, 1.0), m_in1 = PointMass(2.0), m_in2 = NormalWeightedMeanPrecision(1.0 / 3.0, 1.0 / 3.0), meta = TinyCorrection()), + input = (m_out = NormalMeanVariance(4.0, 1.0), m_in1 = PointMass(2.0), m_in2 = NormalWeightedMeanPrecision(1.0 / 3.0, 1.0 / 3.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = (in1 = PointMass(2.0), in2 = NormalWeightedMeanPrecision(25 / 3, 13 / 3)) ) ] diff --git a/test/rules/dot_product/test_out.jl b/test/rules/dot_product/test_out.jl index cac35402a..1040b9dea 100644 --- a/test/rules/dot_product/test_out.jl +++ b/test/rules/dot_product/test_out.jl @@ -6,6 +6,7 @@ using Random import ReactiveMP: @test_rules import LinearAlgebra: dot +import MatrixCorrectionTools: NoCorrection, ReplaceZeroDiagonalEntries @testset "rules:typeof(dot):out" begin @testset "Belief Propagation: (m_in1::PointMass, m_in2::NormalDistributionsFamily)" begin @@ -16,9 +17,9 @@ import LinearAlgebra: dot ] @test_rules [check_type_promotion = true] typeof(dot)(:out, Marginalisation) [ - (input = (m_in1 = PointMass(1.0), m_in2 = NormalMeanVariance(2.0, 2.0), meta = TinyCorrection()), output = NormalMeanVariance(2.0, 2.0)), - (input = (m_in1 = PointMass(-1.0), m_in2 = NormalMeanPrecision(3.0, 1.0), meta = TinyCorrection()), output = NormalMeanVariance(-3.0, 1.0)), - (input = (m_in1 = PointMass(2.0), m_in2 = NormalWeightedMeanPrecision(2.0, 0.5), meta = TinyCorrection()), output = NormalMeanVariance(8.0, 8.0)) + (input = (m_in1 = PointMass(1.0), m_in2 = NormalMeanVariance(2.0, 2.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalMeanVariance(2.0, 2.0)), + (input = (m_in1 = PointMass(-1.0), m_in2 = NormalMeanPrecision(3.0, 1.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalMeanVariance(-3.0, 1.0)), + (input = (m_in1 = PointMass(2.0), m_in2 = NormalWeightedMeanPrecision(2.0, 0.5), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalMeanVariance(8.0, 8.0)) ] end @@ -30,24 +31,24 @@ import LinearAlgebra: dot ] @test_rules [check_type_promotion = true] typeof(dot)(:out, Marginalisation) [ - (input = (m_in1 = NormalMeanVariance(2.0, 2.0), m_in2 = PointMass(4.0), meta = TinyCorrection()), output = NormalMeanVariance(8.0, 32.0)), - (input = (m_in1 = NormalMeanPrecision(2.0, inv(3.0)), m_in2 = PointMass(2.0), meta = TinyCorrection()), output = NormalMeanVariance(4.0, 12.0)), - (input = (m_in1 = NormalWeightedMeanPrecision(2.0, 0.5), m_in2 = PointMass(1.0), meta = TinyCorrection()), output = NormalMeanVariance(4.0, 2.0)) + (input = (m_in1 = NormalMeanVariance(2.0, 2.0), m_in2 = PointMass(4.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalMeanVariance(8.0, 32.0)), + (input = (m_in1 = NormalMeanPrecision(2.0, inv(3.0)), m_in2 = PointMass(2.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalMeanVariance(4.0, 12.0)), + (input = (m_in1 = NormalWeightedMeanPrecision(2.0, 0.5), m_in2 = PointMass(1.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalMeanVariance(4.0, 2.0)) ] end @testset "Belief Propagation: (m_in1::MultivariateNormalDistributionsFamily, m_in2::PointMass)" begin @test_rules [check_type_promotion = true] typeof(dot)(:out, Marginalisation) [ ( - input = (m_in1 = MvNormalMeanCovariance([-1.0, 1.0], [2.0 -1.0; -1.0 4.0]), m_in2 = PointMass([4.0, 1.0]), meta = TinyCorrection()), + input = (m_in1 = MvNormalMeanCovariance([-1.0, 1.0], [2.0 -1.0; -1.0 4.0]), m_in2 = PointMass([4.0, 1.0]), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalMeanVariance(-3, 28) ), ( - input = (m_in1 = MvNormalMeanPrecision([2.0, 1.0], [2.0 -0.5; -0.5 5.0]), m_in2 = PointMass([2.0, 2.0]), meta = TinyCorrection()), + input = (m_in1 = MvNormalMeanPrecision([2.0, 1.0], [2.0 -0.5; -0.5 5.0]), m_in2 = PointMass([2.0, 2.0]), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalMeanVariance(6.0, 128 / 39) ), ( - input = (m_in1 = MvNormalWeightedMeanPrecision([3.0, 2.0], [10.0 1.0; 1.0 20.0]), m_in2 = PointMass([-1.0, 3.0]), meta = TinyCorrection()), + input = (m_in1 = MvNormalWeightedMeanPrecision([3.0, 2.0], [10.0 1.0; 1.0 20.0]), m_in2 = PointMass([-1.0, 3.0]), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalMeanVariance(-7 / 199, 116 / 199) ) ] @@ -71,15 +72,15 @@ import LinearAlgebra: dot @testset "Belief Propagation: (m_in1::PointMass, m_in2::MultivariateNormalDistributionsFamily)" begin @test_rules [check_type_promotion = true] typeof(dot)(:out, Marginalisation) [ ( - input = (m_in1 = PointMass([4.0, 1.0]), m_in2 = MvNormalMeanCovariance([-1.0, 1.0], [2.0 -1.0; -1.0 4.0]), meta = TinyCorrection()), + input = (m_in1 = PointMass([4.0, 1.0]), m_in2 = MvNormalMeanCovariance([-1.0, 1.0], [2.0 -1.0; -1.0 4.0]), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalMeanVariance(-3, 28) ), ( - input = (m_in1 = PointMass([2.0, 2.0]), m_in2 = MvNormalMeanPrecision([2.0, 1.0], [2.0 -0.5; -0.5 5.0]), meta = TinyCorrection()), + input = (m_in1 = PointMass([2.0, 2.0]), m_in2 = MvNormalMeanPrecision([2.0, 1.0], [2.0 -0.5; -0.5 5.0]), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalMeanVariance(6.0, 128 / 39) ), ( - input = (m_in1 = PointMass([-1.0, 3.0]), m_in2 = MvNormalWeightedMeanPrecision([3.0, 2.0], [10.0 1.0; 1.0 20.0]), meta = TinyCorrection()), + input = (m_in1 = PointMass([-1.0, 3.0]), m_in2 = MvNormalWeightedMeanPrecision([3.0, 2.0], [10.0 1.0; 1.0 20.0]), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalMeanVariance(-7 / 199, 116 / 199) ) ] diff --git a/test/rules/multiplication/test_A.jl b/test/rules/multiplication/test_A.jl index 856284351..be30a4237 100644 --- a/test/rules/multiplication/test_A.jl +++ b/test/rules/multiplication/test_A.jl @@ -11,9 +11,9 @@ import ReactiveMP: make_inversedist_message d1 = NormalMeanVariance(0.0, 1.0) d2 = NormalMeanVariance(0.5, 1.5) d3 = NormalMeanVariance(2.0, 0.5) - OutMessage_1 = @call_rule typeof(*)(:A, Marginalisation) (m_out = d1, m_in = d2, meta = TinyCorrection()) - OutMessage_2 = @call_rule typeof(*)(:A, Marginalisation) (m_out = d1, m_in = d3, meta = TinyCorrection()) - OutMessage_3 = @call_rule typeof(*)(:A, Marginalisation) (m_out = d2, m_in = d3, meta = TinyCorrection()) + OutMessage_1 = @call_rule typeof(*)(:A, Marginalisation) (m_out = d1, m_in = d2) + OutMessage_2 = @call_rule typeof(*)(:A, Marginalisation) (m_out = d1, m_in = d3) + OutMessage_3 = @call_rule typeof(*)(:A, Marginalisation) (m_out = d2, m_in = d3) groundtruthOutMessage_1 = (x) -> -log(abs(x)) - 0.5 * log(2π * (var(d2) + var(d1) / x^2)) - 1 / 2 * (mean(d1) - x * mean(d2))^2 / (var(d2) * x^2 + var(d1)) groundtruthOutMessage_2 = (x) -> -log(abs(x)) - 0.5 * log(2π * (var(d3) + var(d1) / x^2)) - 1 / 2 * (mean(d1) - x * mean(d3))^2 / (var(d3) * x^2 + var(d1)) groundtruthOutMessage_3 = (x) -> -log(abs(x)) - 0.5 * log(2π * (var(d3) + var(d2) / x^2)) - 1 / 2 * (mean(d2) - x * mean(d3))^2 / (var(d3) * x^2 + var(d2)) @@ -36,7 +36,7 @@ import ReactiveMP: make_inversedist_message num_samples = 3000 samples_d2 = rand(rng, d2, num_samples) - OutMessage = @call_rule typeof(*)(:A, Marginalisation) (m_out = d1, m_in = d2, meta = TinyCorrection()) + OutMessage = @call_rule typeof(*)(:A, Marginalisation) (m_out = d1, m_in = d2) @test typeof(OutMessage) <: ContinuousUnivariateLogPdf diff --git a/test/rules/multiplication/test_in.jl b/test/rules/multiplication/test_in.jl index 0b38b66ef..88c65d23e 100644 --- a/test/rules/multiplication/test_in.jl +++ b/test/rules/multiplication/test_in.jl @@ -11,9 +11,9 @@ import ReactiveMP: @test_rules, make_inversedist_message d1 = NormalMeanVariance(0.0, 1.0) d2 = NormalMeanVariance(0.5, 1.5) d3 = NormalMeanVariance(2.0, 0.5) - OutMessage_1 = @call_rule typeof(*)(:in, Marginalisation) (m_out = d1, m_A = d2, meta = TinyCorrection()) - OutMessage_2 = @call_rule typeof(*)(:in, Marginalisation) (m_out = d1, m_A = d3, meta = TinyCorrection()) - OutMessage_3 = @call_rule typeof(*)(:in, Marginalisation) (m_out = d2, m_A = d3, meta = TinyCorrection()) + OutMessage_1 = @call_rule typeof(*)(:in, Marginalisation) (m_out = d1, m_A = d2) + OutMessage_2 = @call_rule typeof(*)(:in, Marginalisation) (m_out = d1, m_A = d3) + OutMessage_3 = @call_rule typeof(*)(:in, Marginalisation) (m_out = d2, m_A = d3) groundtruthOutMessage_1 = (x) -> -log(abs(x)) - 0.5 * log(2π * (var(d2) + var(d1) / x^2)) - 1 / 2 * (mean(d1) - x * mean(d2))^2 / (var(d2) * x^2 + var(d1)) groundtruthOutMessage_2 = (x) -> -log(abs(x)) - 0.5 * log(2π * (var(d3) + var(d1) / x^2)) - 1 / 2 * (mean(d1) - x * mean(d3))^2 / (var(d3) * x^2 + var(d1)) groundtruthOutMessage_3 = (x) -> -log(abs(x)) - 0.5 * log(2π * (var(d3) + var(d2) / x^2)) - 1 / 2 * (mean(d2) - x * mean(d3))^2 / (var(d3) * x^2 + var(d2)) @@ -36,7 +36,7 @@ import ReactiveMP: @test_rules, make_inversedist_message num_samples = 3000 samples_d2 = rand(rng, d2, num_samples) - OutMessage = @call_rule typeof(*)(:in, Marginalisation) (m_out = d1, m_A = d2, meta = TinyCorrection()) + OutMessage = @call_rule typeof(*)(:in, Marginalisation) (m_out = d1, m_A = d2) @test typeof(OutMessage) <: ContinuousUnivariateLogPdf diff --git a/test/rules/multiplication/test_out.jl b/test/rules/multiplication/test_out.jl index 86655e33e..935769741 100644 --- a/test/rules/multiplication/test_out.jl +++ b/test/rules/multiplication/test_out.jl @@ -11,9 +11,9 @@ import ReactiveMP: @test_rules, besselmod, make_productdist_message d1 = NormalMeanVariance(0.0, 1.0) d2 = NormalMeanVariance(0.5, 1.5) d3 = NormalMeanVariance(2.0, 0.5) - OutMessage_1 = @call_rule typeof(*)(:out, Marginalisation) (m_A = d1, m_in = d2, meta = TinyCorrection()) - OutMessage_2 = @call_rule typeof(*)(:out, Marginalisation) (m_A = d1, m_in = d3, meta = TinyCorrection()) - OutMessage_3 = @call_rule typeof(*)(:out, Marginalisation) (m_A = d2, m_in = d3, meta = TinyCorrection()) + OutMessage_1 = @call_rule typeof(*)(:out, Marginalisation) (m_A = d1, m_in = d2) + OutMessage_2 = @call_rule typeof(*)(:out, Marginalisation) (m_A = d1, m_in = d3) + OutMessage_3 = @call_rule typeof(*)(:out, Marginalisation) (m_A = d2, m_in = d3) groundtruthOutMessage_1 = besselmod(mean(d1), var(d1), mean(d2), var(d2), 0.0) groundtruthOutMessage_2 = besselmod(mean(d1), var(d1), mean(d3), var(d3), 0.0) groundtruthOutMessage_3 = besselmod(mean(d2), var(d2), mean(d3), var(d3), 0.0) @@ -36,7 +36,7 @@ import ReactiveMP: @test_rules, besselmod, make_productdist_message num_samples = 3000 samples_d1 = rand(rng, d1, num_samples) - OutMessage = @call_rule typeof(*)(:out, Marginalisation) (m_A = d1, m_in = d2, meta = TinyCorrection()) + OutMessage = @call_rule typeof(*)(:out, Marginalisation) (m_A = d1, m_in = d2) @test typeof(OutMessage) <: ContinuousUnivariateLogPdf diff --git a/test/runtests.jl b/test/runtests.jl index 7a878f530..b99066392 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -200,7 +200,6 @@ end @test filename_to_key(key_to_filename("message")) == "message" end - addtests(testrunner, "algebra/test_correction.jl") addtests(testrunner, "algebra/test_common.jl") addtests(testrunner, "algebra/test_permutation_matrix.jl") addtests(testrunner, "algebra/test_standard_basis_vector.jl") From 77d5d67c62b618540cc08be8a4b03e489436aca4 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Mon, 25 Sep 2023 14:33:26 +0200 Subject: [PATCH 2/5] 2prev --- src/ReactiveMP.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/ReactiveMP.jl b/src/ReactiveMP.jl index 0d8a6f644..878c27561 100644 --- a/src/ReactiveMP.jl +++ b/src/ReactiveMP.jl @@ -2,8 +2,9 @@ module ReactiveMP # List global dependencies here -using TinyHugeNumbers -using MatrixCorrectionTools: AbstractCorrectionStrategy, correction! +using TinyHugeNumbers, MatrixCorrectionTools + +import MatrixCorrectionTools: AbstractCorrectionStrategy, correction! # Reexport `tiny` and `huge` from the `TinyHugeNumbers` export tiny, huge From 2475a6fed149f21b58b70f2a2e0a113692f5d704 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Mon, 25 Sep 2023 16:48:24 +0200 Subject: [PATCH 3/5] small fix for the meta specification --- src/constraints/specifications/meta.jl | 44 ++++++++++++++------------ 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/src/constraints/specifications/meta.jl b/src/constraints/specifications/meta.jl index 4c0d18b59..5ee501140 100644 --- a/src/constraints/specifications/meta.jl +++ b/src/constraints/specifications/meta.jl @@ -65,30 +65,34 @@ See also: [`ConstraintsSpecification`](@ref) function resolve_meta(metaspec, fform, variables) symfform = as_node_symbol(fform) - var_names = map(name, TupleTools.flatten(variables)) - var_refs = map(resolve_variable_proxy, TupleTools.flatten(variables)) - var_refs_names = map(r -> r[1], var_refs) - found = nothing unrolled_foreach(getentries(metaspec)) do fentry # We iterate over all entries in the meta specification - if functionalform(fentry) === symfform && (all(s -> s ∈ var_names, getnames(fentry)) || all(s -> s ∈ var_refs_names, getnames(fentry))) - if isnothing(found) - # if we find an appropriate meta spec we simply set it - found = fentry - elseif !isnothing(found) && issubset(getnames(fentry), getnames(found)) && issubset(getnames(found), getnames(fentry)) - # The error case is the meta specification collision, two sets of names are exactly the same - error("Ambigous meta object resolution for the node $(fform). Check $(found) and $(fentry).") - elseif !isnothing(found) && issubset(getnames(fentry), getnames(found)) - # If we find another matching meta spec, but it has fewer names in it we simply keep the previous one - nothing - elseif !isnothing(found) && issubset(getnames(found), getnames(fentry)) - # If we find another matching meta spec, and it has more names we override the previous one - found = fentry - elseif !isnothing(found) && !issubset(getnames(fentry), getnames(found)) && !issubset(getnames(found), getnames(fentry)) - # The error case is the meta specification collision, two sets of names are different and do not include each other - error("Ambigous meta object resolution for the node $(fform). Check $(found) and $(fentry).") + if functionalform(fentry) === symfform + # The `var_names` & `var_refs_names` should be done only if we hit the required entry + # otherwise it would be too error prone, because many nodes cannot properly resolve their `var_names` (e.g. deterministic nodes with more than one input) + # but there might be no meta specification for such nodes, currently the algorithm recompute those for each hit, this can probably be improved + local var_names = map(name, TupleTools.flatten(variables)) + local var_refs = map(resolve_variable_proxy, TupleTools.flatten(variables)) + local var_refs_names = map(r -> r[1], var_refs) + if (all(s -> s ∈ var_names, getnames(fentry)) || all(s -> s ∈ var_refs_names, getnames(fentry))) + if isnothing(found) + # if we find an appropriate meta spec we simply set it + found = fentry + elseif !isnothing(found) && issubset(getnames(fentry), getnames(found)) && issubset(getnames(found), getnames(fentry)) + # The error case is the meta specification collision, two sets of names are exactly the same + error("Ambigous meta object resolution for the node $(fform). Check $(found) and $(fentry).") + elseif !isnothing(found) && issubset(getnames(fentry), getnames(found)) + # If we find another matching meta spec, but it has fewer names in it we simply keep the previous one + nothing + elseif !isnothing(found) && issubset(getnames(found), getnames(fentry)) + # If we find another matching meta spec, and it has more names we override the previous one + found = fentry + elseif !isnothing(found) && !issubset(getnames(fentry), getnames(found)) && !issubset(getnames(found), getnames(fentry)) + # The error case is the meta specification collision, two sets of names are different and do not include each other + error("Ambigous meta object resolution for the node $(fform). Check $(found) and $(fentry).") + end end end end From 117062f0f537efa6b9d111c09afe5724200f0fee Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Mon, 25 Sep 2023 16:54:53 +0200 Subject: [PATCH 4/5] bump version to 3.11 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d31cb00ba..758271cab 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ReactiveMP" uuid = "a194aa59-28ba-4574-a09c-4a745416d6e3" authors = ["Dmitry Bagaev ", "Albert Podusenko ", "Bart van Erp ", "Ismail Senoz "] -version = "3.10.0" +version = "3.11.0" [deps] DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" From cc70acd00e1c5e35eaaf50853167174258304697 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Tue, 26 Sep 2023 11:32:41 +0200 Subject: [PATCH 5/5] style: make format --- src/rules/multiplication/A.jl | 4 +++- src/rules/multiplication/in.jl | 4 +++- src/rules/multiplication/out.jl | 24 ++++++++++++++---------- test/rules/dot_product/test_in1.jl | 11 ++++++++--- test/rules/dot_product/test_in2.jl | 10 ++++++++-- test/rules/dot_product/test_marginals.jl | 8 ++++++-- 6 files changed, 42 insertions(+), 19 deletions(-) diff --git a/src/rules/multiplication/A.jl b/src/rules/multiplication/A.jl index 5e3fd0b60..0768d0db8 100644 --- a/src/rules/multiplication/A.jl +++ b/src/rules/multiplication/A.jl @@ -53,7 +53,9 @@ end return NormalWeightedMeanPrecision(dot(tmp, μ_out), W) end -@rule typeof(*)(:A, Marginalisation) (m_out::UnivariateGaussianDistributionsFamily, m_in::UnivariateGaussianDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin +@rule typeof(*)(:A, Marginalisation) ( + m_out::UnivariateGaussianDistributionsFamily, m_in::UnivariateGaussianDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing} +) = begin μ_in, var_in = mean_var(m_in) μ_out, var_out = mean_var(m_out) log_backwardpass = (x) -> -log(abs(x)) - 0.5 * log(2π * (var_in + var_out / x^2)) - 1 / 2 * (μ_out - x * μ_in)^2 / (var_in * x^2 + var_out) diff --git a/src/rules/multiplication/in.jl b/src/rules/multiplication/in.jl index fb1727737..1ebc4db42 100644 --- a/src/rules/multiplication/in.jl +++ b/src/rules/multiplication/in.jl @@ -54,7 +54,9 @@ end return NormalWeightedMeanPrecision(dot(tmp, μ_out), W) end -@rule typeof(*)(:in, Marginalisation) (m_out::UnivariateGaussianDistributionsFamily, m_A::UnivariateGaussianDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin +@rule typeof(*)(:in, Marginalisation) ( + m_out::UnivariateGaussianDistributionsFamily, m_A::UnivariateGaussianDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing} +) = begin μ_A, var_A = mean_var(m_A) μ_out, var_out = mean_var(m_out) log_backwardpass = (x) -> -log(abs(x)) - 0.5 * log(2π * (var_A + var_out / x^2)) - 1 / 2 * (μ_out - x * μ_A)^2 / (var_A * x^2 + var_out) diff --git a/src/rules/multiplication/out.jl b/src/rules/multiplication/out.jl index c4494287e..1fc462ad4 100644 --- a/src/rules/multiplication/out.jl +++ b/src/rules/multiplication/out.jl @@ -10,16 +10,18 @@ end return @call_rule typeof(*)(:out, Marginalisation) (m_A = m_in, m_in = m_A, meta = meta, addons = getaddons()) # symmetric rule end -@rule typeof(*)(:out, Marginalisation) (m_A::PointMass{<:AbstractMatrix}, m_in::F, meta::Union{<:AbstractCorrectionStrategy, Nothing}) where {F <: NormalDistributionsFamily} = begin - @logscale 0 - A = mean(m_A) - μ_in, Σ_in = mean_cov(m_in) - return convert(promote_variate_type(F, NormalMeanVariance), A * μ_in, A * Σ_in * A') -end +@rule typeof(*)(:out, Marginalisation) (m_A::PointMass{<:AbstractMatrix}, m_in::F, meta::Union{<:AbstractCorrectionStrategy, Nothing}) where {F <: NormalDistributionsFamily} = + begin + @logscale 0 + A = mean(m_A) + μ_in, Σ_in = mean_cov(m_in) + return convert(promote_variate_type(F, NormalMeanVariance), A * μ_in, A * Σ_in * A') + end -@rule typeof(*)(:out, Marginalisation) (m_A::F, m_in::PointMass{<:AbstractMatrix}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) where {F <: NormalDistributionsFamily} = begin - return @call_rule typeof(*)(:out, Marginalisation) (m_A = m_in, m_in = m_A, meta = meta, addons = getaddons()) # symmetric rule -end +@rule typeof(*)(:out, Marginalisation) (m_A::F, m_in::PointMass{<:AbstractMatrix}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) where {F <: NormalDistributionsFamily} = + begin + return @call_rule typeof(*)(:out, Marginalisation) (m_A = m_in, m_in = m_A, meta = meta, addons = getaddons()) # symmetric rule + end #------------------------ # AbstractVector * UnivariateNormalDistributions @@ -68,7 +70,9 @@ end #----------------------- # Univariate Normal * Univariate Normal #---------------------- -@rule typeof(*)(:out, Marginalisation) (m_A::UnivariateGaussianDistributionsFamily, m_in::UnivariateGaussianDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin +@rule typeof(*)(:out, Marginalisation) ( + m_A::UnivariateGaussianDistributionsFamily, m_in::UnivariateGaussianDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing} +) = begin μ_A, var_A = mean_var(m_A) μ_in, var_in = mean_var(m_in) diff --git a/test/rules/dot_product/test_in1.jl b/test/rules/dot_product/test_in1.jl index 2260a06dc..8ad923297 100644 --- a/test/rules/dot_product/test_in1.jl +++ b/test/rules/dot_product/test_in1.jl @@ -9,7 +9,6 @@ import LinearAlgebra: dot import MatrixCorrectionTools: NoCorrection, AddToDiagonalEntries, ReplaceZeroDiagonalEntries @testset "rules:typeof(dot):in1" begin - @testset "Belief Propagation: (m_out::UnivariateNormalDistributionsFamily, m_in2::PointMass)" begin @test_rules [check_type_promotion = true] typeof(dot)(:in1, Marginalisation) [ (input = (m_out = NormalMeanVariance(2.0, 2.0), m_in2 = PointMass(-1.0), meta = NoCorrection()), output = NormalWeightedMeanPrecision(-1.0, 0.5)), @@ -20,7 +19,10 @@ import MatrixCorrectionTools: NoCorrection, AddToDiagonalEntries, ReplaceZeroDia @test_rules [check_type_promotion = true] typeof(dot)(:in1, Marginalisation) [ (input = (m_out = NormalMeanVariance(2.0, 2.0), m_in2 = PointMass(-1.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(-1.0, 0.5)), (input = (m_out = NormalMeanPrecision(1.0, 1.0), m_in2 = PointMass(-2.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(-2.0, 4.0)), - (input = (m_out = NormalWeightedMeanPrecision(2.0, 1.0), m_in2 = PointMass(-1.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(-2.0, 1.0)) + ( + input = (m_out = NormalWeightedMeanPrecision(2.0, 1.0), m_in2 = PointMass(-1.0), meta = ReplaceZeroDiagonalEntries(tiny)), + output = NormalWeightedMeanPrecision(-2.0, 1.0) + ) ] @test_rules [check_type_promotion = true] typeof(dot)(:in1, Marginalisation) [ @@ -32,7 +34,10 @@ import MatrixCorrectionTools: NoCorrection, AddToDiagonalEntries, ReplaceZeroDia @test_rules [check_type_promotion = true] typeof(dot)(:in1, Marginalisation) [ (input = (m_out = NormalMeanVariance(2.0, 2.0), m_in2 = PointMass(0.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(0.0, tiny)), (input = (m_out = NormalMeanPrecision(1.0, 1.0), m_in2 = PointMass(0.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(0.0, tiny)), - (input = (m_out = NormalWeightedMeanPrecision(2.0, 1.0), m_in2 = PointMass(0.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(0.0, tiny)) + ( + input = (m_out = NormalWeightedMeanPrecision(2.0, 1.0), m_in2 = PointMass(0.0), meta = ReplaceZeroDiagonalEntries(tiny)), + output = NormalWeightedMeanPrecision(0.0, tiny) + ) ] end diff --git a/test/rules/dot_product/test_in2.jl b/test/rules/dot_product/test_in2.jl index a2b12d35e..2ebd350e6 100644 --- a/test/rules/dot_product/test_in2.jl +++ b/test/rules/dot_product/test_in2.jl @@ -18,7 +18,10 @@ import MatrixCorrectionTools: NoCorrection, AddToDiagonalEntries, ReplaceZeroDia @test_rules [check_type_promotion = true] typeof(dot)(:in2, Marginalisation) [ (input = (m_out = NormalMeanVariance(2.0, 2.0), m_in1 = PointMass(-1.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(-1.0, 0.5)), (input = (m_out = NormalMeanPrecision(1.0, 1.0), m_in1 = PointMass(-2.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(-2.0, 4.0)), - (input = (m_out = NormalWeightedMeanPrecision(2.0, 1.0), m_in1 = PointMass(-1.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(-2.0, 1.0)) + ( + input = (m_out = NormalWeightedMeanPrecision(2.0, 1.0), m_in1 = PointMass(-1.0), meta = ReplaceZeroDiagonalEntries(tiny)), + output = NormalWeightedMeanPrecision(-2.0, 1.0) + ) ] @test_rules [check_type_promotion = true] typeof(dot)(:in2, Marginalisation) [ @@ -30,7 +33,10 @@ import MatrixCorrectionTools: NoCorrection, AddToDiagonalEntries, ReplaceZeroDia @test_rules [check_type_promotion = true] typeof(dot)(:in2, Marginalisation) [ (input = (m_out = NormalMeanVariance(2.0, 2.0), m_in1 = PointMass(0.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(0.0, tiny)), (input = (m_out = NormalMeanPrecision(1.0, 1.0), m_in1 = PointMass(0.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(0.0, tiny)), - (input = (m_out = NormalWeightedMeanPrecision(2.0, 1.0), m_in1 = PointMass(0.0), meta = ReplaceZeroDiagonalEntries(tiny)), output = NormalWeightedMeanPrecision(0.0, tiny)) + ( + input = (m_out = NormalWeightedMeanPrecision(2.0, 1.0), m_in1 = PointMass(0.0), meta = ReplaceZeroDiagonalEntries(tiny)), + output = NormalWeightedMeanPrecision(0.0, tiny) + ) ] end diff --git a/test/rules/dot_product/test_marginals.jl b/test/rules/dot_product/test_marginals.jl index b8101e780..2222bebd3 100644 --- a/test/rules/dot_product/test_marginals.jl +++ b/test/rules/dot_product/test_marginals.jl @@ -129,7 +129,9 @@ import MatrixCorrectionTools: NoCorrection, ReplaceZeroDiagonalEntries output = (in1 = NormalWeightedMeanPrecision(6.5, 4.5), in2 = PointMass(2.0)) ), ( - input = (m_out = NormalMeanVariance(4.0, 1.0), m_in1 = NormalWeightedMeanPrecision(1.0 / 3.0, 1.0 / 3.0), m_in2 = PointMass(2.0), meta = ReplaceZeroDiagonalEntries(tiny)), + input = ( + m_out = NormalMeanVariance(4.0, 1.0), m_in1 = NormalWeightedMeanPrecision(1.0 / 3.0, 1.0 / 3.0), m_in2 = PointMass(2.0), meta = ReplaceZeroDiagonalEntries(tiny) + ), output = (in1 = NormalWeightedMeanPrecision(25 / 3, 13 / 3), in2 = PointMass(2.0)) ) ] @@ -193,7 +195,9 @@ import MatrixCorrectionTools: NoCorrection, ReplaceZeroDiagonalEntries output = (in1 = PointMass(2.0), in2 = NormalWeightedMeanPrecision(6.5, 4.5)) ), ( - input = (m_out = NormalMeanVariance(4.0, 1.0), m_in1 = PointMass(2.0), m_in2 = NormalWeightedMeanPrecision(1.0 / 3.0, 1.0 / 3.0), meta = ReplaceZeroDiagonalEntries(tiny)), + input = ( + m_out = NormalMeanVariance(4.0, 1.0), m_in1 = PointMass(2.0), m_in2 = NormalWeightedMeanPrecision(1.0 / 3.0, 1.0 / 3.0), meta = ReplaceZeroDiagonalEntries(tiny) + ), output = (in1 = PointMass(2.0), in2 = NormalWeightedMeanPrecision(25 / 3, 13 / 3)) ) ]