Skip to content

Commit

Permalink
Merge pull request #365 from biaslab/dev-fix-rxinfer-30
Browse files Browse the repository at this point in the history
Allow const/data inputs for the non-linear deterministic nodes
  • Loading branch information
bvdmitri authored Nov 28, 2023
2 parents 178d16c + c16d0be commit e1bcebf
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 17 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
ExponentialFamily = "62312e5e-252a-4322-ace9-a5f4bf9b357b"
FastCholesky = "2d5283b6-8564-42b6-bb00-83ed8e915756"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
FixedArguments = "4130a065-6d82-41fe-881e-7a5c65156f7d"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
Expand Down Expand Up @@ -44,13 +45,14 @@ ReactiveMPRequiresExt = "Requires"
[compat]
BayesBase = "1.1.0"
DataStructures = "0.17, 0.18"
Distributions = "0.24, 0.25"
DiffResults = "1.1.0"
Distributions = "0.24, 0.25"
DomainIntegrals = "0.3.2, 0.4"
DomainSets = "0.5.2, 0.6, 0.7"
ExponentialFamily = "1.2.0"
FastCholesky = "1.3.0"
FastGaussQuadrature = "0.4, 0.5"
FixedArguments = "0.1"
ForwardDiff = "0.10"
HCubature = "1.0.0"
LazyArrays = "0.21, 0.22, 1"
Expand Down
78 changes: 64 additions & 14 deletions src/nodes/delta/delta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,27 @@ getinverse(meta::DeltaMeta, k::Int) = meta.inverse[k]
import Base: map

struct DeltaFn{F} end
struct DeltaFnNode{F, N, L, M} <: AbstractFactorNode
struct DeltaFnNode{F, P, N, S, L, M} <: AbstractFactorNode
fn::F

proxy::P
out::NodeInterface
ins::NTuple{N, IndexedNodeInterface}

statics :: S
localmarginals :: L
metadata :: M
end

as_node_symbol(::Type{<:DeltaFn{F}}) where {F} = Symbol(replace(string(nameof(F)), "#" => ""))

functionalform(factornode::DeltaFnNode{F}) where {F} = DeltaFn{F}
sdtype(factornode::DeltaFnNode) = Deterministic()
interfaces(factornode::DeltaFnNode) = (factornode.out, factornode.ins...)
factorisation(factornode::DeltaFnNode{F, N}) where {F, N} = ntuple(identity, N + 1)
localmarginals(factornode::DeltaFnNode) = factornode.localmarginals.marginals
localmarginalnames(factornode::DeltaFnNode) = map(name, localmarginals(factornode))
metadata(factornode::DeltaFnNode) = factornode.metadata
functionalform(factornode::DeltaFnNode{F}) where {F} = DeltaFn{F}
sdtype(factornode::DeltaFnNode) = Deterministic()
interfaces(factornode::DeltaFnNode) = (factornode.out, factornode.ins...)
factorisation(factornode::DeltaFnNode{F}) where {F} = ntuple(identity, length(factornode.ins) + 1)
localmarginals(factornode::DeltaFnNode) = factornode.localmarginals.marginals
localmarginalnames(factornode::DeltaFnNode) = map(name, localmarginals(factornode))
metadata(factornode::DeltaFnNode) = factornode.metadata

collect_meta(::Type{D}, something::Nothing) where {D <: DeltaFn} = error(
"Delta node `$(as_node_symbol(D))` requires meta specification with the `where { meta = ... }` in the `@model` macro or with the separate `@meta` specification. See documentation for the `DeltaMeta`."
Expand All @@ -63,7 +65,7 @@ function nodefunction(factornode::DeltaFnNode)
end
end

nodefunction(factornode::DeltaFnNode, ::Val{:out}) = factornode.fn
nodefunction(factornode::DeltaFnNode, ::Val{:out}) = factornode.proxy
nodefunction(factornode::DeltaFnNode, ::Val{:in}) = getinverse(metadata(factornode))
nodefunction(factornode::DeltaFnNode, ::Val{:in}, k::Integer) = getinverse(metadata(factornode), k)

Expand All @@ -86,7 +88,7 @@ call_rule_is_node_required(::Type{<:DeltaFn}) = CallRuleNodeRequired()
function call_rule_make_node(::CallRuleNodeRequired, fformtype::Type{<:DeltaFn}, nodetype::F, meta::DeltaMeta) where {F}
# This node is not initialized properly, but we do not expect rules to access internal uninitialized fields.
# Doing so will most likely throw an error
return DeltaFnNode(nodetype, NodeInterface(:out, Marginalisation()), (), nothing, collect_meta(DeltaFn{F}, meta))
return DeltaFnNode(nodetype, nodetype, NodeInterface(:out, Marginalisation()), (), (), nothing, collect_meta(DeltaFn{F}, meta))
end

function interfaceindex(factornode::DeltaFnNode, iname::Symbol)
Expand All @@ -100,20 +102,37 @@ function interfaceindex(factornode::DeltaFnNode, iname::Symbol)
end
end

function __make_delta_fn_node(fn::F, options::FactorNodeCreationOptions, out::AbstractVariable, ins::NTuple{N, <:AbstractVariable}) where {F <: Function, N}
import FixedArguments
import FixedArguments: FixedArgument, FixedPosition

function __make_delta_fn_node(fn::F, options::FactorNodeCreationOptions, out::AbstractVariable, ins::Tuple) where {F <: Function}
out_interface = NodeInterface(:out, Marginalisation())
ins_interface = ntuple(i -> IndexedNodeInterface(i, NodeInterface(:in, Marginalisation())), N)

# The inputs for the deterministic function are being splitted into two groups:
# 1. Random variables and 2. Const/Data variables (static inputs)
randoms, statics = __split_static_inputs(ins)

# We create interfaces only for random variables
# The static variables are being passed to the `FixedArguments.fix` function
ins_interface = ntuple(i -> IndexedNodeInterface(i, NodeInterface(:in, Marginalisation())), length(randoms))

out_index = getlastindex(out)
connectvariable!(out_interface, out, out_index)
setmessagein!(out, out_index, messageout(out_interface))

foreach(zip(ins_interface, ins)) do (in_interface, in_var)
foreach(zip(ins_interface, randoms)) do (in_interface, in_var)
in_index = getlastindex(in_var)
connectvariable!(in_interface, in_var, in_index)
setmessagein!(in_var, in_index, messageout(in_interface))
end

foreach(statics) do static
setused!(FixedArguments.value(static))
end

# The proxy is the actual node function, but with the static inputs already fixed at their respective position
# We use the `__unpack_latest_static` function to get the latest value of the static variables
proxy = FixedArguments.fix(fn, __unpack_latest_static, statics)
localmarginals = FactorNodeLocalMarginals((FactorNodeLocalMarginal(1, 1, :out), FactorNodeLocalMarginal(2, 2, :ins)))
meta = collect_meta(DeltaFn{F}, metadata(options))
pipeline = getpipeline(options)
Expand All @@ -122,9 +141,40 @@ function __make_delta_fn_node(fn::F, options::FactorNodeCreationOptions, out::Ab
@warn "Delta node does not support the `pipeline` option."
end

return DeltaFnNode(fn, out_interface, ins_interface, localmarginals, meta)
if !isnothing(getinverse(meta)) && !isempty(statics)
error("The inverse function specification is not supported for the Delta node, which is connected to datavar/constvar edges.")
end

return DeltaFnNode(fn, proxy, out_interface, ins_interface, statics, localmarginals, meta)
end

# This function takes the inputs of the deterministic nodes and sorts them into two
# groups: the first group is of type `RandomVariable` and the second group is of type `ConstVariable/DataVariable`
__split_static_inputs(ins::Tuple) = __split_static_inputs(Val(1), (), (), ins)

# Stop if the `remaining` tuple is empty
__split_static_inputs(::Val{N}, randoms, statics, remaining::Tuple{}) where {N} = (randoms, statics)
# Split the `remaining` tuple into head (current) and tail (remaining)
__split_static_inputs(::Val{N}, randoms, statics, remaining::Tuple) where {N} = __split_static_inputs(Val(N), randoms, statics, first(remaining), Base.tail(remaining))

# If the current input is a random variable, we add it to the `randoms` tuple
function __split_static_inputs(::Val{N}, randoms, statics, current::RandomVariable, remaining::Tuple) where {N}
return __split_static_inputs(Val(N + 1), (randoms..., current), statics, remaining)
end

# If the current input is a const/data variable, we add it to the `statics` tuple with its respective position
function __split_static_inputs(::Val{N}, randoms, statics, current::Union{ConstVariable, DataVariable}, remaining::Tuple) where {N}
return __split_static_inputs(Val(N + 1), randoms, (statics..., FixedArgument(FixedPosition(N), current)), remaining)
end

# This function is used to unpack the latest value of the static variables
# For constvar we just return the value
# For datavar we get the latest value from the data stream
__unpack_latest_static(_, constvar::ConstVariable) = getconst(constvar)
__unpack_latest_static(_, datavar::DataVariable) = BayesBase.getpointmass(getdata(Rocket.getrecent(messageout(datavar, 1))))

##

function make_node(fform::F, options::FactorNodeCreationOptions, args::Vararg{<:AbstractVariable}) where {F <: Function}
return __make_delta_fn_node(fform, options, args[1], args[2:end])
end
Expand Down
1 change: 1 addition & 0 deletions src/nodes/delta/layouts/cvi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ function deltafn_apply_layout(::CVIApproximationDeltaFnRuleLayout, ::Val{:m_out}
(dependencies) -> VariationalMessage(dependencies[1], dependencies[2], messagemap)
end

vmessageout = with_statics(factornode, vmessageout)
vmessageout = vmessageout |> map(AbstractMessage, mapping)
vmessageout = apply_pipeline_stage(pipeline_stages, factornode, vtag, vmessageout)
vmessageout = vmessageout |> schedule_on(scheduler)
Expand Down
26 changes: 24 additions & 2 deletions src/nodes/delta/layouts/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,24 @@ See also: [`ReactiveMP.DeltaFnDefaultKnownInverseRuleLayout`](@ref)
"""
struct DeltaFnDefaultRuleLayout <: AbstractDeltaNodeDependenciesLayout end

import FixedArguments

function with_statics(factornode::DeltaFnNode, stream)
return with_statics(factornode, factornode.statics, stream)
end

function with_statics(factornode::DeltaFnNode, statics::Tuple, stream::T) where {T}
# We wait for the statics to be available, but ignore their actual values
# They are being injected indirectly with the `fix` function upon node creation
statics = map(static -> messageout(static, 1), FixedArguments.value.(factornode.statics))
return combineLatest((stream, combineLatest(statics, PushNew()))) |> map(eltype(T), first)
end

function with_statics(factornode::DeltaFnNode, statics::Tuple{}, stream::T) where {T}
# There is no need to touch the original stream if there are no statics
return stream
end

# This function declares how to compute `q_out` locally around `DeltaFn`
function deltafn_apply_layout(::DeltaFnDefaultRuleLayout, ::Val{:q_out}, factornode::DeltaFnNode, pipeline_stages, scheduler, addons)
let out = factornode.out, localmarginal = factornode.localmarginals.marginals[1]
Expand Down Expand Up @@ -44,7 +62,7 @@ function deltafn_apply_layout(::DeltaFnDefaultRuleLayout, ::Val{:q_ins}, factorn
meta = metadata(factornode)

mapping = MarginalMapping(fform, vtag, msgs_names, marginal_names, meta, factornode)
marginalout = combineLatest((msgs_observable, marginals_observable), PushNew()) |> discontinue() |> map(Marginal, mapping)
marginalout = with_statics(factornode, combineLatest((msgs_observable, marginals_observable), PushNew())) |> discontinue() |> map(Marginal, mapping)

connect!(cmarginal, marginalout) # MarginalObservable has RecentSubject by default, there is no need to share_recent() here
end
Expand Down Expand Up @@ -72,6 +90,7 @@ function deltafn_apply_layout(::DeltaFnDefaultRuleLayout, ::Val{:m_out}, factorn
(dependencies) -> VariationalMessage(dependencies[1], dependencies[2], messagemap)
end

vmessageout = with_statics(factornode, vmessageout)
vmessageout = vmessageout |> map(AbstractMessage, mapping)
vmessageout = apply_pipeline_stage(pipeline_stages, factornode, vtag, vmessageout)
vmessageout = vmessageout |> schedule_on(scheduler)
Expand Down Expand Up @@ -101,6 +120,7 @@ function deltafn_apply_layout(::DeltaFnDefaultRuleLayout, ::Val{:m_in}, factorno
(dependencies) -> VariationalMessage(dependencies[1], dependencies[2], messagemap)
end

vmessageout = with_statics(factornode, vmessageout)
vmessageout = vmessageout |> map(AbstractMessage, mapping)
vmessageout = apply_pipeline_stage(pipeline_stages, factornode, vtag, vmessageout)
vmessageout = vmessageout |> schedule_on(scheduler)
Expand Down Expand Up @@ -140,7 +160,8 @@ function deltafn_apply_layout(::DeltaFnDefaultKnownInverseRuleLayout, ::Val{:m_o
end

# This function declares how to compute `m_in`
function deltafn_apply_layout(::DeltaFnDefaultKnownInverseRuleLayout, ::Val{:m_in}, factornode::DeltaFnNode{F, N}, pipeline_stages, scheduler, addons) where {F, N}
function deltafn_apply_layout(::DeltaFnDefaultKnownInverseRuleLayout, ::Val{:m_in}, factornode::DeltaFnNode{F}, pipeline_stages, scheduler, addons) where {F}
N = length(factornode.ins)
# For each outbound message from `in_k` edge we need an inbound messages from all OTHER! `in_*` edges and inbound message on `m_out`
foreach(enumerate(factornode.ins)) do (index, interface)

Expand Down Expand Up @@ -170,6 +191,7 @@ function deltafn_apply_layout(::DeltaFnDefaultKnownInverseRuleLayout, ::Val{:m_i
(dependencies) -> VariationalMessage(dependencies[1], dependencies[2], messagemap)
end

vmessageout = with_statics(factornode, vmessageout)
vmessageout = vmessageout |> map(AbstractMessage, mapping)
vmessageout = apply_pipeline_stage(pipeline_stages, factornode, vtag, vmessageout)
vmessageout = vmessageout |> schedule_on(scheduler)
Expand Down
1 change: 1 addition & 0 deletions src/variables/constant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ degree(constvar::ConstVariable) = nconnected(constvar)
name(constvar::ConstVariable) = constvar.name
proxy_variables(constvar::ConstVariable) = nothing
collection_type(constvar::ConstVariable) = constvar.collection_type
setused!(constvar::ConstVariable) = nothing

isproxy(::ConstVariable) = false

Expand Down
1 change: 1 addition & 0 deletions src/variables/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ proxy_variables(datavar::DataVariable) = nothing # not related to isproxy
collection_type(datavar::DataVariable) = datavar.collection_type
isconnected(datavar::DataVariable) = datavar.nconnected !== 0
nconnected(datavar::DataVariable) = datavar.nconnected
setused!(datavar::DataVariable) = datavar.isused = true

isproxy(datavar::DataVariable) = datavar.isproxy
isused(datavar::DataVariable) = datavar.isused
Expand Down
65 changes: 65 additions & 0 deletions test/nodes/delta/test_delta.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
module DeltaNodeTest

using Test, ReactiveMP, Random

@testset "DeltaNode" begin
@testset "Creation with static inputs (simple case) #1" begin
import ReactiveMP: nodefunction, FactorNodeCreationOptions

foo(x, y, z) = x * y + z

out = randomvar(:out)

x = randomvar(:x)
y = datavar(:y, Float64)
z = constvar(:z, 3.0)

node = make_node(foo, FactorNodeCreationOptions(nothing, Linearization(), nothing), out, x, y, z)

update!(y, 2.0)

for xval in rand(10)
@test nodefunction(node, Val(:out))(xval) === foo(xval, 2.0, 3.0)
@test nodefunction(node)(foo(xval, 2.0, 3.0), xval) === 1
@test nodefunction(node)(foo(xval, 2.0, 3.0) + 1.0, xval) === 0
end
end

@testset "Creation with static inputs (all permutations) #2" begin
import ReactiveMP: nodefunction, FactorNodeCreationOptions

foo1(x, y, z) = x * y + z
foo2(x, y, z) = x / y - z
foo3(x, y, z) = x - y * z

out = randomvar(:out)
opt = FactorNodeCreationOptions(nothing, Linearization(), nothing)

for vals in [rand(Float64, 3) for _ in 1:10], foo in (foo1, foo2, foo3)

# In this test we attempt to create a lot of possible combinations
# of random, data and const inputs to the delta node
create_interfaces(i) = (randomvar(:x), datavar(:y, Float64), constvar(:z, vals[i]))

for x in create_interfaces(1), y in create_interfaces(2), z in create_interfaces(3)
interfaces = [x, y, z]

rpos = findall(i -> i isa RandomVariable, interfaces)
node = make_node(foo, opt, out, interfaces...)

# data variable inputs require an actual update
foreach(enumerate(interfaces)) do (i, interface)
if interface isa DataVariable
update!(interface, vals[i])
end
end

@test nodefunction(node, Val(:out))(vals[rpos]...) === foo(vals...)
@test nodefunction(node)(foo(vals...), vals[rpos]...) === 1
@test nodefunction(node)(foo(vals...) + 1, vals[rpos]...) === 0
end
end
end
end

end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ end
addtests(testrunner, "pipeline/test_logger.jl")

addtests(testrunner, "test_node.jl")
addtests(testrunner, "nodes/delta/test_delta.jl")
addtests(testrunner, "nodes/flow/test_flow.jl")
addtests(testrunner, "nodes/test_addition.jl")
addtests(testrunner, "nodes/test_bifm.jl")
Expand Down

0 comments on commit e1bcebf

Please sign in to comment.