Use MatrixCorrectionTools.jl #351

Merged
merged 5 commits on Sep 26, 2023
4 changes: 3 additions & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "ReactiveMP"
uuid = "a194aa59-28ba-4574-a09c-4a745416d6e3"
authors = ["Dmitry Bagaev <d.v.bagaev@tue.nl>", "Albert Podusenko <a.podusenko@tue.nl>", "Bart van Erp <b.v.erp@tue.nl>", "Ismail Senoz <i.senoz@tue.nl>"]
version = "3.10.0"
version = "3.11.0"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
@@ -15,6 +15,7 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
MatrixCorrectionTools = "41f81499-25de-46de-b591-c3cfc21e9eaf"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
PositiveFactorizations = "85a6dd25-e78a-55b7-8502-1745935b8125"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -49,6 +50,7 @@ HCubature = "1.0.0"
LazyArrays = "0.21, 0.22, 1"
LoopVectorization = "0.12"
MacroTools = "0.5"
MatrixCorrectionTools = "1.2.0"
Optim = "1.0.0"
Optimisers = "0.2"
PositiveFactorizations = "0.2"
5 changes: 3 additions & 2 deletions src/ReactiveMP.jl
@@ -2,7 +2,9 @@
module ReactiveMP

# List global dependencies here
using TinyHugeNumbers
using TinyHugeNumbers, MatrixCorrectionTools

import MatrixCorrectionTools: AbstractCorrectionStrategy, correction!

# Reexport `tiny` and `huge` from the `TinyHugeNumbers`
export tiny, huge
@@ -16,7 +18,6 @@ include("score/counting.jl")

include("helpers/algebra/cholesky.jl")
include("helpers/algebra/companion_matrix.jl")
include("helpers/algebra/correction.jl")
include("helpers/algebra/common.jl")
include("helpers/algebra/permutation_matrix.jl")
include("helpers/algebra/standard_basis_vector.jl")
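Note: with `AbstractCorrectionStrategy` and `correction!` now imported from MatrixCorrectionTools, custom strategies can plug into the same interface the rules below dispatch on. A minimal sketch, assuming that adding a method to `correction!(strategy, M)` (the call shape used by the rules in this PR) is the intended extension point; the `AddJitter` type is hypothetical and only for illustration:

```julia
using MatrixCorrectionTools
import MatrixCorrectionTools: AbstractCorrectionStrategy, correction!

# Hypothetical strategy for illustration: add a small constant to every diagonal entry
struct AddJitter <: AbstractCorrectionStrategy
    value::Float64
end

# Assumed extension point, mirroring how the rules below call `correction!(meta, A' * W_out * A)`
function correction!(strategy::AddJitter, M::AbstractMatrix)
    for i in 1:size(M, 1)
        M[i, i] += strategy.value
    end
    return M
end

correction!(AddJitter(1e-8), [1.0 0.0; 0.0 0.0])  # both diagonal entries get 1e-8 added
```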
44 changes: 24 additions & 20 deletions src/constraints/specifications/meta.jl
@@ -65,30 +65,34 @@
function resolve_meta(metaspec, fform, variables)
symfform = as_node_symbol(fform)

var_names = map(name, TupleTools.flatten(variables))
var_refs = map(resolve_variable_proxy, TupleTools.flatten(variables))
var_refs_names = map(r -> r[1], var_refs)

found = nothing

unrolled_foreach(getentries(metaspec)) do fentry
# We iterate over all entries in the meta specification
if functionalform(fentry) === symfform && (all(s -> s ∈ var_names, getnames(fentry)) || all(s -> s ∈ var_refs_names, getnames(fentry)))
if isnothing(found)
# if we find an appropriate meta spec we simply set it
found = fentry
elseif !isnothing(found) && issubset(getnames(fentry), getnames(found)) && issubset(getnames(found), getnames(fentry))
# The error case is the meta specification collision, two sets of names are exactly the same
error("Ambigous meta object resolution for the node $(fform). Check $(found) and $(fentry).")
elseif !isnothing(found) && issubset(getnames(fentry), getnames(found))
# If we find another matching meta spec, but it has fewer names in it we simply keep the previous one
nothing
elseif !isnothing(found) && issubset(getnames(found), getnames(fentry))
# If we find another matching meta spec, and it has more names we override the previous one
found = fentry
elseif !isnothing(found) && !issubset(getnames(fentry), getnames(found)) && !issubset(getnames(found), getnames(fentry))
# The error case is the meta specification collision, two sets of names are different and do not include each other
error("Ambigous meta object resolution for the node $(fform). Check $(found) and $(fentry).")
if functionalform(fentry) === symfform

# The `var_names` & `var_refs_names` should be done only if we hit the required entry
# otherwise it would be too error prone, because many nodes cannot properly resolve their `var_names` (e.g. deterministic nodes with more than one input)
# but there might be no meta specification for such nodes, currently the algorithm recompute those for each hit, this can probably be improved
local var_names = map(name, TupleTools.flatten(variables))
local var_refs = map(resolve_variable_proxy, TupleTools.flatten(variables))
local var_refs_names = map(r -> r[1], var_refs)
if (all(s -> s ∈ var_names, getnames(fentry)) || all(s -> s ∈ var_refs_names, getnames(fentry)))
if isnothing(found)

# if we find an appropriate meta spec we simply set it
found = fentry
elseif !isnothing(found) && issubset(getnames(fentry), getnames(found)) && issubset(getnames(found), getnames(fentry))

# The error case is the meta specification collision, two sets of names are exactly the same
error("Ambigous meta object resolution for the node $(fform). Check $(found) and $(fentry).")
elseif !isnothing(found) && issubset(getnames(fentry), getnames(found))

# If we find another matching meta spec, but it has fewer names in it we simply keep the previous one
nothing
elseif !isnothing(found) && issubset(getnames(found), getnames(fentry))

# If we find another matching meta spec, and it has more names we override the previous one
found = fentry
elseif !isnothing(found) && !issubset(getnames(fentry), getnames(found)) && !issubset(getnames(found), getnames(fentry))

# The error case is the meta specification collision, two sets of names are different and do not include each other
error("Ambigous meta object resolution for the node $(fform). Check $(found) and $(fentry).")

end
end
end
end
102 changes: 0 additions & 102 deletions src/helpers/algebra/correction.jl

This file was deleted.
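The internal correction helpers that lived in this file (including the old `TinyCorrection` used as the previous node default) are superseded by MatrixCorrectionTools.jl. A minimal sketch of the replacement call, assuming `correction!` returns the corrected matrix the way the rule code below uses it:

```julia
using LinearAlgebra
using TinyHugeNumbers
using MatrixCorrectionTools
import MatrixCorrectionTools: correction!

# A precision matrix with a zero diagonal entry is not invertible as-is
W = [1.0 0.0; 0.0 0.0]

# The strategy used as the new node default: replace zero diagonal entries with `tiny`
strategy = MatrixCorrectionTools.ReplaceZeroDiagonalEntries(tiny)
W_corrected = correction!(strategy, W)  # W[2, 2] is replaced by `tiny`

inv(W_corrected)  # no longer throws a SingularException
```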

4 changes: 2 additions & 2 deletions src/nodes/dot_product.jl
Expand Up @@ -4,5 +4,5 @@

@node typeof(dot) Deterministic [out, in1, in2]

# By default dot-product node uses TinyCorrection() strategy for precision matrix on `in1` and `in2` edges to ensure precision is always invertible
default_meta(::typeof(dot)) = TinyCorrection()
# By default dot-product node uses `MatrixCorrectionTools.ReplaceZeroDiagonalEntries(tiny)` strategy for precision matrix on `in1` and `in2` edges to ensure precision is always invertible
default_meta(::typeof(dot)) = MatrixCorrectionTools.ReplaceZeroDiagonalEntries(tiny)

4 changes: 2 additions & 2 deletions src/nodes/multiplication.jl
@@ -1,5 +1,5 @@

@node typeof(*) Deterministic [out, A, in]

# By default multiplication node uses TinyCorrection() strategy for precision matrix on `in` edge to ensure precision is always invertible
default_meta(::typeof(*)) = TinyCorrection()
# By default multiplication node uses `MatrixCorrectionTools.ReplaceZeroDiagonalEntries(tiny)` strategy for precision matrix on `in` edge to ensure precision is always invertible
default_meta(::typeof(*)) = MatrixCorrectionTools.ReplaceZeroDiagonalEntries(tiny)

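Both the dot-product and multiplication nodes now share the same external default strategy. A short sketch of how the new default can be inspected or swapped for a different correction constant, assuming `dot` here is `LinearAlgebra.dot` (as used by the rule bodies) and that `default_meta` is reachable via the `ReactiveMP` module; the `1e-6` value is arbitrary, for illustration only:

```julia
using ReactiveMP, MatrixCorrectionTools
using LinearAlgebra: dot

# The new defaults introduced by this PR
ReactiveMP.default_meta(*)    # MatrixCorrectionTools.ReplaceZeroDiagonalEntries(tiny)
ReactiveMP.default_meta(dot)  # MatrixCorrectionTools.ReplaceZeroDiagonalEntries(tiny)

# An alternative strategy with a different correction constant can be passed as node meta instead
custom_strategy = MatrixCorrectionTools.ReplaceZeroDiagonalEntries(1e-6)
```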
2 changes: 1 addition & 1 deletion src/rules/dot_product/in1.jl
@@ -1,4 +1,4 @@

@rule typeof(dot)(:in1, Marginalisation) (m_out::UnivariateNormalDistributionsFamily, m_in2::PointMass, meta::AbstractCorrection) = begin
@rule typeof(dot)(:in1, Marginalisation) (m_out::UnivariateNormalDistributionsFamily, m_in2::PointMass, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin
return @call_rule typeof(dot)(:in2, Marginalisation) (m_out = m_out, m_in1 = m_in2, meta = meta)
end
2 changes: 1 addition & 1 deletion src/rules/dot_product/in2.jl
@@ -1,5 +1,5 @@

@rule typeof(dot)(:in2, Marginalisation) (m_out::UnivariateNormalDistributionsFamily, m_in1::PointMass, meta::AbstractCorrection) = begin
@rule typeof(dot)(:in2, Marginalisation) (m_out::UnivariateNormalDistributionsFamily, m_in1::PointMass, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin
A = mean(m_in1)
out_wmean, out_prec = weightedmean_precision(m_out)

4 changes: 2 additions & 2 deletions src/rules/dot_product/marginals.jl
@@ -1,5 +1,5 @@

@marginalrule typeof(dot)(:in1_in2) (m_out::NormalDistributionsFamily, m_in1::PointMass, m_in2::NormalDistributionsFamily, meta::AbstractCorrection) = begin
@marginalrule typeof(dot)(:in1_in2) (m_out::NormalDistributionsFamily, m_in1::PointMass, m_in2::NormalDistributionsFamily, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin

# Forward message towards `in2` edge
mf_in2 = @call_rule typeof(dot)(:in2, Marginalisation) (m_out = m_out, m_in1 = m_in1, meta = meta)
@@ -8,7 +8,7 @@
return convert_paramfloattype((in1 = m_in1, in2 = q_in2))
end

@marginalrule typeof(dot)(:in1_in2) (m_out::NormalDistributionsFamily, m_in1::NormalDistributionsFamily, m_in2::PointMass, meta::AbstractCorrection) = begin
@marginalrule typeof(dot)(:in1_in2) (m_out::NormalDistributionsFamily, m_in1::NormalDistributionsFamily, m_in2::PointMass, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin
symmetric = @call_marginalrule typeof(dot)(:in1_in2) (m_out = m_out, m_in1 = m_in2, m_in2 = m_in1, meta = meta)
return convert_paramfloattype((in1 = symmetric[:in2], in2 = symmetric[:in1]))
end
4 changes: 2 additions & 2 deletions src/rules/dot_product/out.jl
@@ -1,9 +1,9 @@

@rule typeof(dot)(:out, Marginalisation) (m_in1::NormalDistributionsFamily, m_in2::PointMass, meta::AbstractCorrection) = begin
@rule typeof(dot)(:out, Marginalisation) (m_in1::NormalDistributionsFamily, m_in2::PointMass, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin
return @call_rule typeof(dot)(:out, Marginalisation) (m_in1 = m_in2, m_in2 = m_in1, meta = meta)
end

@rule typeof(dot)(:out, Marginalisation) (m_in1::PointMass, m_in2::NormalDistributionsFamily, meta::AbstractCorrection) = begin
@rule typeof(dot)(:out, Marginalisation) (m_in1::PointMass, m_in2::NormalDistributionsFamily, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin
A = mean(m_in1)
in2_mean, in2_cov = mean_cov(m_in2)
return NormalMeanVariance(dot(A, in2_mean), dot(A, in2_cov, A))
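With the signature change from the old `AbstractCorrection` to `Union{AbstractCorrectionStrategy, Nothing}`, these rules accept either an external strategy or no correction at all. A hedged sketch of invoking the `:out` rule directly with both options, following the `@call_rule` call shape used inside these files; the concrete distributions below are only example inputs:

```julia
using ReactiveMP, MatrixCorrectionTools
using LinearAlgebra: dot

m_in1 = PointMass([1.0, 2.0])
m_in2 = MvNormalMeanCovariance([0.0, 0.0], [1.0 0.0; 0.0 1.0])

# With an explicit correction strategy (the new node default)
out_corrected = @call_rule typeof(dot)(:out, Marginalisation) (
    m_in1 = m_in1, m_in2 = m_in2, meta = MatrixCorrectionTools.ReplaceZeroDiagonalEntries(tiny)
)

# With `meta = nothing` the same rule still dispatches, just without a correction strategy
out_plain = @call_rule typeof(dot)(:out, Marginalisation) (m_in1 = m_in1, m_in2 = m_in2, meta = nothing)
```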
20 changes: 11 additions & 9 deletions src/rules/multiplication/A.jl
@@ -1,12 +1,12 @@

@rule typeof(*)(:A, Marginalisation) (m_out::PointMass, m_in::PointMass, meta::Union{<:AbstractCorrection, Nothing}) = PointMass(mean(m_in) \ mean(m_out))
@rule typeof(*)(:A, Marginalisation) (m_out::PointMass, m_in::PointMass, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = PointMass(mean(m_in) \ mean(m_out))

@rule typeof(*)(:A, Marginalisation) (m_out::GammaDistributionsFamily, m_in::PointMass{<:Real}, meta::Union{<:AbstractCorrection, Nothing}) = begin
@rule typeof(*)(:A, Marginalisation) (m_out::GammaDistributionsFamily, m_in::PointMass{<:Real}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
return GammaShapeRate(shape(m_out), rate(m_out) * mean(m_in))
end

# if A is a matrix, then the result is multivariate
@rule typeof(*)(:A, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_in::PointMass{<:AbstractMatrix}, meta::Union{<:AbstractCorrection, Nothing}) = begin
@rule typeof(*)(:A, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_in::PointMass{<:AbstractMatrix}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
A = mean(m_in)
ξ_out, W_out = weightedmean_precision(m_out)
W = correction!(meta, A' * W_out * A)
@@ -15,23 +15,23 @@ end

# if A is a vector, then the result is univariate
# this rule links to the special case (AbstractVector * Univariate) for forward (:out) rule
@rule typeof(*)(:A, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_in::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrection, Nothing}) = begin
@rule typeof(*)(:A, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_in::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
A = mean(m_in)
ξ_out, W_out = weightedmean_precision(m_out)
W = correction!(meta, dot(A, W_out, A))
return NormalWeightedMeanPrecision(dot(A, ξ_out), W)
end

# if A is a scalar, then the input is either univariate or multivariate
@rule typeof(*)(:A, Marginalisation) (m_out::F, m_in::PointMass{<:Real}, meta::Union{<:AbstractCorrection, Nothing}) where {F <: NormalDistributionsFamily} = begin
@rule typeof(*)(:A, Marginalisation) (m_out::F, m_in::PointMass{<:Real}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) where {F <: NormalDistributionsFamily} = begin
A = mean(m_in)
ξ_out, W_out = weightedmean_precision(m_out)
W = correction!(meta, A^2 * W_out)
return convert(promote_variate_type(F, NormalWeightedMeanPrecision), A * ξ_out, W)
end

# specialized versions for mean-covariance parameterization
@rule typeof(*)(:A, Marginalisation) (m_out::MvNormalMeanCovariance, m_in::PointMass{<:AbstractMatrix}, meta::Union{<:AbstractCorrection, Nothing}) = begin
@rule typeof(*)(:A, Marginalisation) (m_out::MvNormalMeanCovariance, m_in::PointMass{<:AbstractMatrix}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
A = mean(m_in)
μ_out, Σ_out = mean_cov(m_out)

@@ -42,7 +42,7 @@ end
return MvNormalWeightedMeanPrecision(tmp * μ_out, W)
end

@rule typeof(*)(:A, Marginalisation) (m_out::MvNormalMeanCovariance, m_in::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrection, Nothing}) = begin
@rule typeof(*)(:A, Marginalisation) (m_out::MvNormalMeanCovariance, m_in::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
A = mean(m_in)
μ_out, Σ_out = mean_cov(m_out)

@@ -53,14 +53,16 @@ end
return NormalWeightedMeanPrecision(dot(tmp, μ_out), W)
end

@rule typeof(*)(:A, Marginalisation) (m_out::UnivariateGaussianDistributionsFamily, m_in::UnivariateGaussianDistributionsFamily, meta::Union{<:AbstractCorrection, Nothing}) = begin
@rule typeof(*)(:A, Marginalisation) (
m_out::UnivariateGaussianDistributionsFamily, m_in::UnivariateGaussianDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing}
) = begin
μ_in, var_in = mean_var(m_in)
μ_out, var_out = mean_var(m_out)
log_backwardpass = (x) -> -log(abs(x)) - 0.5 * log(2π * (var_in + var_out / x^2)) - 1 / 2 * (μ_out - x * μ_in)^2 / (var_in * x^2 + var_out)
return ContinuousUnivariateLogPdf(log_backwardpass)
end

@rule typeof(*)(:A, Marginalisation) (m_out::UnivariateDistribution, m_in::UnivariateDistribution, meta::Union{<:AbstractCorrection, Nothing}) = begin
@rule typeof(*)(:A, Marginalisation) (m_out::UnivariateDistribution, m_in::UnivariateDistribution, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
nsamples = 3000
samples_in = rand(m_in, nsamples)
p = make_inversedist_message(samples_in, m_out)