Skip to content

Commit

Permalink
Merge pull request #381 from ReactiveBayes/fullFactorisation_new_name
Browse files Browse the repository at this point in the history
Full factorisation new name
  • Loading branch information
bvdmitri authored Mar 6, 2024
2 parents 6328823 + b38bdfe commit d023367
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 36 deletions.
2 changes: 1 addition & 1 deletion src/constraints/specifications/constraints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ const __EmptyConstraints = ConstraintsSpecification((), (;), (;), ConstraintsSpe

__reset_preallocated!(specification::ConstraintsSpecification, size::Int) = __reset_preallocated!(specification.preallocated, size)

function activate!(::Union{UnspecifiedConstraints, MeanField, FullFactorisation}, ::FactorNodesCollection, ::VariablesCollection)
function activate!(::Union{UnspecifiedConstraints, MeanField, BetheFactorisation}, ::FactorNodesCollection, ::VariablesCollection)
return nothing
end

Expand Down
8 changes: 4 additions & 4 deletions src/constraints/specifications/factorisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ resolve_factorisation(::UnspecifiedConstraints, allvariables, fform, variables)
resolve_factorisation(::UnspecifiedConstraints, any, allvariables, fform, variables) = resolve_factorisation(__EmptyConstraints, allvariables, fform, variables)

# Preoptimised dispatch rule for unspecified constraints and a deterministic node with any number of inputs
resolve_factorisation(::UnspecifiedConstraints, ::Deterministic, allvariables, fform, variables) = FullFactorisation()
resolve_factorisation(::UnspecifiedConstraints, ::Deterministic, allvariables, fform, variables) = BetheFactorisation()

# Preoptimised dispatch rules for unspecified constraints and a stochastic node with 2 inputs, random variable & constant variable
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2}) where {V1 <: RandomVariable, V2 <: RandomVariable} = ((1, 2),)
Expand Down Expand Up @@ -251,13 +251,13 @@ function resolve_factorisation(constraints, allvariables, fform, variables)
return resolve_factorisation(sdtype(fform), constraints, allvariables, fform, variables)
end

# Deterministic nodes always have `FullFactorisation` constraint (by default)
# Deterministic nodes always have `BetheFactorisation` constraint (by default)
function resolve_factorisation(::Deterministic, constraints, allvariables, fform, variables)
return FullFactorisation()
return BetheFactorisation()
end

# We simply return `constraints` if we get global factorisation constraints
function resolve_factorisation(::Stochastic, constraints::Union{MeanField, FullFactorisation}, allvariables, fform, variables)
function resolve_factorisation(::Stochastic, constraints::Union{MeanField, BetheFactorisation}, allvariables, fform, variables)
return constraints
end

Expand Down
4 changes: 2 additions & 2 deletions src/constraints/specifications/form.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ function resolve_marginal_messages_form_prod(constraints, name)
end

# Preoptimised dispatch rule for generic global factorisation constraints
resolve_marginal_form_prod(::Union{MeanField, FullFactorisation}, name) = (nothing, nothing)
resolve_messages_form_prod(::Union{MeanField, FullFactorisation}, name) = (nothing, nothing)
resolve_marginal_form_prod(::Union{MeanField, BetheFactorisation}, name) = (nothing, nothing)
resolve_messages_form_prod(::Union{MeanField, BetheFactorisation}, name) = (nothing, nothing)

resolve_marginal_form_prod(constraints, name) = resolve_form_prod(constraints, constraints.marginalsform, name)
resolve_messages_form_prod(constraints, name) = resolve_form_prod(constraints, constraints.messagesform, name)
Expand Down
35 changes: 19 additions & 16 deletions src/node.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
export Deterministic, Stochastic, isdeterministic, isstochastic, sdtype
export MeanField, FullFactorisation, Marginalisation, MomentMatching
export MeanField, FullFactorisation, BetheFactorisation, Marginalisation, MomentMatching
export functionalform, interfaces, factorisation, localmarginals, localmarginalnames, metadata
export FactorNodesCollection, getnodes, getnode_ids
export make_node, FactorNodeCreationOptions
Expand Down Expand Up @@ -125,25 +125,28 @@ as_node_symbol(fn::F) where {F <: Function} = Symbol(fn)
Generic factorisation constraint used to specify a mean-field factorisation for recognition distribution `q`.
See also: [`FullFactorisation`](@ref)
See also: [`BetheFactorisation`](@ref)
"""
struct MeanField end

"""
FullFactorisation
BetheFactorisation
Generic factorisation constraint used to specify a full factorisation for recognition distribution `q`.
Generic factorisation constraint used to specify the Bethe factorisation for recognition distribution `q`.
See also: [`MeanField`](@ref)
"""
struct FullFactorisation end
struct BetheFactorisation end

# Alias for `BetheFactorisation` to deprecate `FullFactorisation`.
Base.@deprecate_binding FullFactorisation BetheFactorisation

"""
collect_factorisation(nodetype, factorisation)
This function converts given factorisation to a correct internal factorisation representation for a given node.
See also: [`MeanField`](@ref), [`FullFactorisation`](@ref)
See also: [`MeanField`](@ref), [`BetheFactorisation`](@ref)
"""
function collect_factorisation end

Expand Down Expand Up @@ -1119,22 +1122,22 @@ macro node(fformtype, sdtype, interfaces_list)

# By default every argument passed to a factorisation option of the node is transformed by
# `collect_factorisation` function to have a tuple like structure.
# The default recipe is simple: for stochastic nodes we convert `FullFactorisation` and `MeanField` objects
# to their tuple of indices equivalents. For deterministic nodes any factorisation is replaced by a FullFactorisation equivalent
# The default recipe is simple: for stochastic nodes we convert `BetheFactorisation` and `MeanField` objects
# to their tuple of indices equivalents. For deterministic nodes any factorisation is replaced by a BetheFactorisation equivalent
factorisation_collectors = if sdtype === :Stochastic
quote
ReactiveMP.collect_factorisation(::$fuppertype, ::Nothing) = ($names_indices,)
ReactiveMP.collect_factorisation(::$fuppertype, factorisation::Tuple) = factorisation
ReactiveMP.collect_factorisation(::$fuppertype, ::ReactiveMP.FullFactorisation) = ($names_indices,)
ReactiveMP.collect_factorisation(::$fuppertype, ::ReactiveMP.MeanField) = $names_splitted_indices
ReactiveMP.collect_factorisation(::$fuppertype, ::Nothing) = ($names_indices,)
ReactiveMP.collect_factorisation(::$fuppertype, factorisation::Tuple) = factorisation
ReactiveMP.collect_factorisation(::$fuppertype, ::ReactiveMP.BetheFactorisation) = ($names_indices,)
ReactiveMP.collect_factorisation(::$fuppertype, ::ReactiveMP.MeanField) = $names_splitted_indices
end

elseif sdtype === :Deterministic
quote
ReactiveMP.collect_factorisation(::$fuppertype, ::Nothing) = ($names_indices,)
ReactiveMP.collect_factorisation(::$fuppertype, factorisation::Tuple) = ($names_indices,)
ReactiveMP.collect_factorisation(::$fuppertype, ::ReactiveMP.FullFactorisation) = ($names_indices,)
ReactiveMP.collect_factorisation(::$fuppertype, ::ReactiveMP.MeanField) = ($names_indices,)
ReactiveMP.collect_factorisation(::$fuppertype, ::Nothing) = ($names_indices,)
ReactiveMP.collect_factorisation(::$fuppertype, factorisation::Tuple) = ($names_indices,)
ReactiveMP.collect_factorisation(::$fuppertype, ::ReactiveMP.BetheFactorisation) = ($names_indices,)
ReactiveMP.collect_factorisation(::$fuppertype, ::ReactiveMP.MeanField) = ($names_indices,)
end
else
error("Unreachable in @node macro.")
Expand Down
24 changes: 12 additions & 12 deletions src/nodes/mixture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ ReactiveMP.as_node_symbol(::Type{<:Mixture}) = :Mixture
#
# Base.show

const MixtureNodeFactorisationSupport = Union{FullFactorisation}
const MixtureNodeFactorisationSupport = Union{BetheFactorisation}

struct MixtureNode{N, F <: MixtureNodeFactorisationSupport, M, P} <: AbstractFactorNode
factorisation::F
Expand Down Expand Up @@ -63,7 +63,7 @@ struct MixtureNodeFunctionalDependencies <: AbstractNodeFunctionalDependenciesPi

default_functional_dependencies_pipeline(::Type{<:Mixture}) = MixtureNodeFunctionalDependencies()

function functional_dependencies(::MixtureNodeFunctionalDependencies, factornode::MixtureNode{N, F}, iindex::Int) where {N, F <: FullFactorisation}
function functional_dependencies(::MixtureNodeFunctionalDependencies, factornode::MixtureNode{N, F}, iindex::Int) where {N, F <: BetheFactorisation}
message_dependencies = if iindex === 1
# output depends on:
(factornode.switch, factornode.inputs)
Expand All @@ -83,7 +83,7 @@ function functional_dependencies(::MixtureNodeFunctionalDependencies, factornode
end

# function for using hard switching
function functional_dependencies(::RequireMarginalFunctionalDependencies, factornode::MixtureNode{N, F}, iindex::Int) where {N, F <: FullFactorisation}
function functional_dependencies(::RequireMarginalFunctionalDependencies, factornode::MixtureNode{N, F}, iindex::Int) where {N, F <: BetheFactorisation}
message_dependencies = if iindex === 1
# output depends on:
(factornode.inputs,)
Expand Down Expand Up @@ -116,7 +116,7 @@ end
# create message observable for output or Mixture edge without pipeline constraints (the message towards the inputs are fine by default behaviour, i.e. they depend only on switch and output and no longer on all other inputs)
function get_messages_observable(
factornode::MixtureNode{N, F, Nothing, FactorNodePipeline{P, EmptyPipelineStage}}, messages::Tuple{NodeInterface, NTuple{N, IndexedNodeInterface}}
) where {N, F <: FullFactorisation, P <: MixtureNodeFunctionalDependencies}
) where {N, F <: BetheFactorisation, P <: MixtureNodeFunctionalDependencies}
output_or_switch_interface = messages[1]
inputsinterfaces = messages[2]

Expand All @@ -130,7 +130,7 @@ end
# create an observable that is used to compute the switch with pipeline constraints
function get_messages_observable(
factornode::MixtureNode{N, F, Nothing, FactorNodePipeline{P, EmptyPipelineStage}}, messages::Tuple{NodeInterface, NTuple{N, IndexedNodeInterface}}
) where {N, F <: FullFactorisation, P <: RequireMarginalFunctionalDependencies}
) where {N, F <: BetheFactorisation, P <: RequireMarginalFunctionalDependencies}
switchinterface = messages[1]
inputsinterfaces = messages[2]

Expand All @@ -144,7 +144,7 @@ end
# create an observable that is used to compute the output with pipeline constraints
function get_messages_observable(
factornode::MixtureNode{N, F, Nothing, FactorNodePipeline{P, EmptyPipelineStage}}, messages::Tuple{NTuple{N, IndexedNodeInterface}}
) where {N, F <: FullFactorisation, P <: RequireMarginalFunctionalDependencies}
) where {N, F <: BetheFactorisation, P <: RequireMarginalFunctionalDependencies}
inputsinterfaces = messages[1]

msgs_names = Val{(name(inputsinterfaces[1]),)}()
Expand All @@ -155,7 +155,7 @@ end
# create an observable that is used to compute the input with pipeline constraints
function get_messages_observable(
factornode::MixtureNode{N, F, Nothing, FactorNodePipeline{P, EmptyPipelineStage}}, messages::Tuple{NodeInterface}
) where {N, F <: FullFactorisation, P <: RequireMarginalFunctionalDependencies}
) where {N, F <: BetheFactorisation, P <: RequireMarginalFunctionalDependencies}
outputinterface = messages[1]

msgs_names = Val{(name(outputinterface),)}()
Expand All @@ -165,7 +165,7 @@ end

function get_marginals_observable(
factornode::MixtureNode{N, F, Nothing, FactorNodePipeline{P, EmptyPipelineStage}}, marginals::Tuple{NodeInterface}
) where {N, F <: FullFactorisation, P <: RequireMarginalFunctionalDependencies}
) where {N, F <: BetheFactorisation, P <: RequireMarginalFunctionalDependencies}
switchinterface = marginals[1]

marginal_names = Val{(name(switchinterface),)}()
Expand All @@ -184,18 +184,18 @@ as_node_functional_form(::Type{<:Mixture}) = ValidNodeFunctionalForm()

sdtype(::Type{<:Mixture}) = Deterministic()

collect_factorisation(::Type{<:Mixture{N}}, factorisation::FullFactorisation) where {N} = factorisation
collect_factorisation(::Type{<:Mixture{N}}, factorisation::BetheFactorisation) where {N} = factorisation
collect_factorisation(::Type{<:Mixture{N}}, factorisation::Any) where {N} = __mixture_incompatible_factorisation_error()

function collect_factorisation(::Type{<:Mixture{N}}, factorisation::NTuple{R, Tuple{<:Integer}}) where {N, R}
# inputs + switch + out, equivalent to FullFactorisation
return (R === N + 2) ? FullFactorisation() : __mixture_incompatible_factorisation_error()
# inputs + switch + out, equivalent to BetheFactorisation
return (R === N + 2) ? BetheFactorisation() : __mixture_incompatible_factorisation_error()
end

__mixture_incompatible_factorisation_error() =
error("`MixtureNode` supports only following global factorisations: [ $(MixtureNodeFactorisationSupport) ] or manually set to equivalent via constraints")

function ReactiveMP.make_node(::Type{<:Mixture{N}}, factorisation::F = FullFactorisation(), meta::M = nothing, pipeline::P = nothing) where {N, F, M, P}
function ReactiveMP.make_node(::Type{<:Mixture{N}}, factorisation::F = BetheFactorisation(), meta::M = nothing, pipeline::P = nothing) where {N, F, M, P}
@assert typeof(factorisation) <: MixtureNodeFactorisationSupport "`MixtureNode` supports only following factorisations: [ $(MixtureNodeFactorisationSupport) ]"
out = NodeInterface(:out, Marginalisation())
switch = NodeInterface(:switch, Marginalisation())
Expand Down
4 changes: 3 additions & 1 deletion test/node_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@

@test ReactiveMP.collect_factorisation(CustomStochasticNode, MeanField()) === ((1,), (2,), (3,), (4,))
@test ReactiveMP.collect_factorisation(CustomStochasticNode, FullFactorisation()) === ((1, 2, 3, 4),)
@test ReactiveMP.collect_factorisation(CustomStochasticNode, BetheFactorisation()) === ((1, 2, 3, 4),)

@test sdtype(CustomStochasticNode) === Stochastic()

Expand Down Expand Up @@ -90,6 +91,7 @@

@test ReactiveMP.collect_factorisation(CustomDeterministicNode, MeanField()) === ((1, 2, 3, 4),)
@test ReactiveMP.collect_factorisation(CustomDeterministicNode, FullFactorisation()) === ((1, 2, 3, 4),)
@test ReactiveMP.collect_factorisation(CustomDeterministicNode, BetheFactorisation()) === ((1, 2, 3, 4),)

@test sdtype(CustomDeterministicNode) === Deterministic()

Expand Down Expand Up @@ -119,7 +121,7 @@

for a in (datavar(:a, Float64), constvar(:a, 1.0)), b in (randomvar(:b),), c in (randomvar(:c),)
@test_logs (:warn, r".*replace `q\(a, b, c\)` with `q\(a\)q\(\.\.\.\)`.*") make_node(
DummyNodeCheckFactorisationWarning, FactorNodeCreationOptions(FullFactorisation(), nothing, nothing), a, b, c
DummyNodeCheckFactorisationWarning, FactorNodeCreationOptions(BetheFactorisation(), nothing, nothing), a, b, c
)
@test_logs (:warn, r".*replace `q\(a, b, c\)` with `q\(a\)q\(\.\.\.\)`.*") make_node(
DummyNodeCheckFactorisationWarning, FactorNodeCreationOptions(((1, 2, 3),), nothing, nothing), a, b, c
Expand Down

0 comments on commit d023367

Please sign in to comment.