From 5f560195e6465d021398a28dfe0c071a17fa318c Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Fri, 24 Nov 2023 12:25:44 +0100 Subject: [PATCH 1/9] add support for const/data inputs in nonlinear nodes --- Project.toml | 4 ++- src/nodes/delta/delta.jl | 53 ++++++++++++++++++++++++++---- src/nodes/delta/layouts/cvi.jl | 1 + src/nodes/delta/layouts/default.jl | 14 +++++++- src/variables/constant.jl | 1 + src/variables/data.jl | 1 + 6 files changed, 65 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index c44786979..ef78934e2 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" ExponentialFamily = "62312e5e-252a-4322-ace9-a5f4bf9b357b" FastCholesky = "2d5283b6-8564-42b6-bb00-83ed8e915756" FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" +FixedArguments = "4130a065-6d82-41fe-881e-7a5c65156f7d" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" @@ -44,13 +45,14 @@ ReactiveMPRequiresExt = "Requires" [compat] BayesBase = "1.1.0" DataStructures = "0.17, 0.18" -Distributions = "0.24, 0.25" DiffResults = "1.1.0" +Distributions = "0.24, 0.25" DomainIntegrals = "0.3.2, 0.4" DomainSets = "0.5.2, 0.6, 0.7" ExponentialFamily = "1.2.0" FastCholesky = "1.3.0" FastGaussQuadrature = "0.4, 0.5" +FixedArguments = "0.1" ForwardDiff = "0.10" HCubature = "1.0.0" LazyArrays = "0.21, 0.22, 1" diff --git a/src/nodes/delta/delta.jl b/src/nodes/delta/delta.jl index a3991686d..99225672a 100644 --- a/src/nodes/delta/delta.jl +++ b/src/nodes/delta/delta.jl @@ -29,12 +29,14 @@ getinverse(meta::DeltaMeta, k::Int) = meta.inverse[k] import Base: map struct DeltaFn{F} end -struct DeltaFnNode{F, N, L, M} <: AbstractFactorNode +struct DeltaFnNode{F, P, N, S, L, M} <: AbstractFactorNode fn::F + proxy::P out::NodeInterface ins::NTuple{N, IndexedNodeInterface} - + + statics :: S localmarginals :: L metadata :: M end @@ -63,7 +65,7 @@ function nodefunction(factornode::DeltaFnNode) end end -nodefunction(factornode::DeltaFnNode, ::Val{:out}) = factornode.fn +nodefunction(factornode::DeltaFnNode, ::Val{:out}) = factornode.proxy nodefunction(factornode::DeltaFnNode, ::Val{:in}) = getinverse(metadata(factornode)) nodefunction(factornode::DeltaFnNode, ::Val{:in}, k::Integer) = getinverse(metadata(factornode), k) @@ -100,20 +102,37 @@ function interfaceindex(factornode::DeltaFnNode, iname::Symbol) end end -function __make_delta_fn_node(fn::F, options::FactorNodeCreationOptions, out::AbstractVariable, ins::NTuple{N, <:AbstractVariable}) where {F <: Function, N} +import FixedArguments +import FixedArguments: FixedArgument, FixedPosition + +function __make_delta_fn_node(fn::F, options::FactorNodeCreationOptions, out::AbstractVariable, ins::Tuple) where {F <: Function} out_interface = NodeInterface(:out, Marginalisation()) - ins_interface = ntuple(i -> IndexedNodeInterface(i, NodeInterface(:in, Marginalisation())), N) + + # The inputs for the deterministic function are being splitted into two groups: + # 1. Random variables and 2. Const/Data variables (static inputs) + randoms, statics = __split_static_inputs(ins) + + # We create interfaces only for random variables + # The static variables are being passed to the `FixedArguments.fix` function + ins_interface = ntuple(i -> IndexedNodeInterface(i, NodeInterface(:in, Marginalisation())), length(randoms)) out_index = getlastindex(out) connectvariable!(out_interface, out, out_index) setmessagein!(out, out_index, messageout(out_interface)) - foreach(zip(ins_interface, ins)) do (in_interface, in_var) + foreach(zip(ins_interface, randoms)) do (in_interface, in_var) in_index = getlastindex(in_var) connectvariable!(in_interface, in_var, in_index) setmessagein!(in_var, in_index, messageout(in_interface)) end + foreach(statics) do static + setused!(FixedArguments.value(static)) + end + + # The proxy is the actual node function, but with the static inputs already fixed at their respective position + # We use the `__unpack_latest_static` function to get the latest value of the static variables + proxy = FixedArguments.fix(fn, __unpack_latest_static, statics) localmarginals = FactorNodeLocalMarginals((FactorNodeLocalMarginal(1, 1, :out), FactorNodeLocalMarginal(2, 2, :ins))) meta = collect_meta(DeltaFn{F}, metadata(options)) pipeline = getpipeline(options) @@ -122,9 +141,29 @@ function __make_delta_fn_node(fn::F, options::FactorNodeCreationOptions, out::Ab @warn "Delta node does not support the `pipeline` option." end - return DeltaFnNode(fn, out_interface, ins_interface, localmarginals, meta) + return DeltaFnNode(fn, proxy, out_interface, ins_interface, statics, localmarginals, meta) +end + +# This function takes the inputs of the deterministic nodes and sorts them into two +# groups: the first group is of type `RandomVariable` and the second group is of type `ConstVariable/DataVariable` +__split_static_inputs(ins::Tuple) = __split_static_inputs(Val(1), (), (), ins) + +__split_static_inputs(::Val{N}, randoms, statics, ins::Tuple{}) where {N} = (randoms, statics) +__split_static_inputs(::Val{N}, randoms, statics, ins::Tuple) where {N} = __split_static_inputs(Val(N), randoms, statics, first(ins), Base.tail(ins)) + +function __split_static_inputs(::Val{N}, randoms, statics, current::RandomVariable, remaining::Tuple) where {N} + return __split_static_inputs(Val(N + 1), (randoms..., current), statics, remaining) end +function __split_static_inputs(::Val{N}, randoms, statics, current::Union{ConstVariable, DataVariable}, remaining::Tuple) where {N} + return __split_static_inputs(Val(N + 1), randoms, (statics..., FixedArgument(FixedPosition(N), current)), remaining) +end + +__unpack_latest_static(_, constvar::ConstVariable) = getconst(constvar) +__unpack_latest_static(_, datavar::DataVariable) = BayesBase.getpointmass(getdata(Rocket.getrecent(messageout(datavar, 1)))) + +## + function make_node(fform::F, options::FactorNodeCreationOptions, args::Vararg{<:AbstractVariable}) where {F <: Function} return __make_delta_fn_node(fform, options, args[1], args[2:end]) end diff --git a/src/nodes/delta/layouts/cvi.jl b/src/nodes/delta/layouts/cvi.jl index 115ef76d1..f020aba59 100644 --- a/src/nodes/delta/layouts/cvi.jl +++ b/src/nodes/delta/layouts/cvi.jl @@ -50,6 +50,7 @@ function deltafn_apply_layout(::CVIApproximationDeltaFnRuleLayout, ::Val{:m_out} (dependencies) -> VariationalMessage(dependencies[1], dependencies[2], messagemap) end + vmessageout = with_statics(factornode, vmessageout) vmessageout = vmessageout |> map(AbstractMessage, mapping) vmessageout = apply_pipeline_stage(pipeline_stages, factornode, vtag, vmessageout) vmessageout = vmessageout |> schedule_on(scheduler) diff --git a/src/nodes/delta/layouts/default.jl b/src/nodes/delta/layouts/default.jl index 3b53f917e..d482bc812 100644 --- a/src/nodes/delta/layouts/default.jl +++ b/src/nodes/delta/layouts/default.jl @@ -17,6 +17,15 @@ See also: [`ReactiveMP.DeltaFnDefaultKnownInverseRuleLayout`](@ref) """ struct DeltaFnDefaultRuleLayout <: AbstractDeltaNodeDependenciesLayout end +import FixedArguments + +function with_statics(factornode::DeltaFnNode, stream::T) where {T} + # We wait for the statics to be available, but ignore their actual values + # They are being injected indirectly with the `fix` function upon node creation + statics = map(static -> messageout(static, 1), FixedArguments.value.(factornode.statics)) + return combineLatest((stream, combineLatest(statics, PushNew()))) |> map(eltype(T), first) +end + # This function declares how to compute `q_out` locally around `DeltaFn` function deltafn_apply_layout(::DeltaFnDefaultRuleLayout, ::Val{:q_out}, factornode::DeltaFnNode, pipeline_stages, scheduler, addons) let out = factornode.out, localmarginal = factornode.localmarginals.marginals[1] @@ -44,7 +53,7 @@ function deltafn_apply_layout(::DeltaFnDefaultRuleLayout, ::Val{:q_ins}, factorn meta = metadata(factornode) mapping = MarginalMapping(fform, vtag, msgs_names, marginal_names, meta, factornode) - marginalout = combineLatest((msgs_observable, marginals_observable), PushNew()) |> discontinue() |> map(Marginal, mapping) + marginalout = with_statics(factornode, combineLatest((msgs_observable, marginals_observable), PushNew())) |> discontinue() |> map(Marginal, mapping) connect!(cmarginal, marginalout) # MarginalObservable has RecentSubject by default, there is no need to share_recent() here end @@ -72,6 +81,7 @@ function deltafn_apply_layout(::DeltaFnDefaultRuleLayout, ::Val{:m_out}, factorn (dependencies) -> VariationalMessage(dependencies[1], dependencies[2], messagemap) end + vmessageout = with_statics(factornode, vmessageout) vmessageout = vmessageout |> map(AbstractMessage, mapping) vmessageout = apply_pipeline_stage(pipeline_stages, factornode, vtag, vmessageout) vmessageout = vmessageout |> schedule_on(scheduler) @@ -101,6 +111,7 @@ function deltafn_apply_layout(::DeltaFnDefaultRuleLayout, ::Val{:m_in}, factorno (dependencies) -> VariationalMessage(dependencies[1], dependencies[2], messagemap) end + vmessageout = with_statics(factornode, vmessageout) vmessageout = vmessageout |> map(AbstractMessage, mapping) vmessageout = apply_pipeline_stage(pipeline_stages, factornode, vtag, vmessageout) vmessageout = vmessageout |> schedule_on(scheduler) @@ -170,6 +181,7 @@ function deltafn_apply_layout(::DeltaFnDefaultKnownInverseRuleLayout, ::Val{:m_i (dependencies) -> VariationalMessage(dependencies[1], dependencies[2], messagemap) end + vmessageout = with_statics(factornode, vmessageout) vmessageout = vmessageout |> map(AbstractMessage, mapping) vmessageout = apply_pipeline_stage(pipeline_stages, factornode, vtag, vmessageout) vmessageout = vmessageout |> schedule_on(scheduler) diff --git a/src/variables/constant.jl b/src/variables/constant.jl index c38936ff4..a03584f3c 100644 --- a/src/variables/constant.jl +++ b/src/variables/constant.jl @@ -71,6 +71,7 @@ degree(constvar::ConstVariable) = nconnected(constvar) name(constvar::ConstVariable) = constvar.name proxy_variables(constvar::ConstVariable) = nothing collection_type(constvar::ConstVariable) = constvar.collection_type +setused!(constvar::ConstVariable) = nothing isproxy(::ConstVariable) = false diff --git a/src/variables/data.jl b/src/variables/data.jl index f45aef7f5..8e0cb499d 100644 --- a/src/variables/data.jl +++ b/src/variables/data.jl @@ -101,6 +101,7 @@ proxy_variables(datavar::DataVariable) = nothing # not related to isproxy collection_type(datavar::DataVariable) = datavar.collection_type isconnected(datavar::DataVariable) = datavar.nconnected !== 0 nconnected(datavar::DataVariable) = datavar.nconnected +setused!(datavar::DataVariable) = datavar.isused = true isproxy(datavar::DataVariable) = datavar.isproxy isused(datavar::DataVariable) = datavar.isused From bef380df7b81769a05b09b621a441ea00ab2b46c Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Fri, 24 Nov 2023 13:53:21 +0100 Subject: [PATCH 2/9] add tests --- src/nodes/delta/delta.jl | 2 +- test/nodes/delta/test_delta.jl | 70 ++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 test/nodes/delta/test_delta.jl diff --git a/src/nodes/delta/delta.jl b/src/nodes/delta/delta.jl index 99225672a..a5add84aa 100644 --- a/src/nodes/delta/delta.jl +++ b/src/nodes/delta/delta.jl @@ -61,7 +61,7 @@ collect_meta(::Type{<:DeltaFn}, method::AbstractApproximationMethod) = DeltaMeta function nodefunction(factornode::DeltaFnNode) # `DeltaFnNode` `nodefunction` is `δ(y - f(ins...))` return let f = nodefunction(factornode, Val(:out)) - (y, ins...) -> ((y - f(ins...)) ≈ 0) ? 1 : 0 + (y, ins...) -> (iszero(y - f(ins...))) ? 1 : 0 end end diff --git a/test/nodes/delta/test_delta.jl b/test/nodes/delta/test_delta.jl new file mode 100644 index 000000000..2cdf428c8 --- /dev/null +++ b/test/nodes/delta/test_delta.jl @@ -0,0 +1,70 @@ +module DeltaNodeTest + +using Test, ReactiveMP, Random + +@testset "DeltaNode" begin + + @testset "Creation with static inputs (simple case) #1" begin + import ReactiveMP: nodefunction, FactorNodeCreationOptions + + foo(x, y, z) = x * y + z + + out = randomvar(:out) + + x = randomvar(:x) + y = datavar(:y, Float64) + z = constvar(:z, 3.0) + + node = make_node(foo, FactorNodeCreationOptions(nothing, Linearization(), nothing), out, x, y, z) + + update!(y, 2.0) + + for xval in rand(10) + @test nodefunction(node, Val(:out))(xval) === foo(xval, 2.0, 3.0) + @test nodefunction(node)(foo(xval, 2.0, 3.0), xval) === 1 + @test nodefunction(node)(foo(xval, 2.0, 3.0) + 1.0, xval) === 0 + end + + end + + @testset "Creation with static inputs (all permutations) #2" begin + import ReactiveMP: nodefunction, FactorNodeCreationOptions + + foo1(x, y, z) = x * y + z + foo2(x, y, z) = x / y - z + foo3(x, y, z) = x - y * z + + out = randomvar(:out) + opt = FactorNodeCreationOptions(nothing, Linearization(), nothing) + + for vals in [ rand(Float64, 3) for _ in 1:10 ], foo in (foo1, foo2, foo3) + + # In this test we attempt to create a lot of possible combinations + # of random, data and const inputs to the delta node + create_interfaces(i) = (randomvar(:x), datavar(:y, Float64), constvar(:z, vals[i])) + + for x in create_interfaces(1), y in create_interfaces(2), z in create_interfaces(3) + interfaces = [ x, y, z ] + + rpos = findall(i -> i isa RandomVariable, interfaces) + node = make_node(foo, opt, out, interfaces...) + + # data variable inputs require an actual update + foreach(enumerate(interfaces)) do (i, interface) + if interface isa DataVariable + update!(interface, vals[i]) + end + end + + @test nodefunction(node, Val(:out))(vals[rpos]...) === foo(vals...) + @test nodefunction(node)(foo(vals...), vals[rpos]...) === 1 + @test nodefunction(node)(foo(vals...) + 1, vals[rpos]...) === 0 + end + + end + + end + +end + +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 2955683e7..87e64fc2e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -228,6 +228,7 @@ end addtests(testrunner, "pipeline/test_logger.jl") addtests(testrunner, "test_node.jl") + addtests(testrunner, "nodes/delta/test_delta.jl") addtests(testrunner, "nodes/flow/test_flow.jl") addtests(testrunner, "nodes/test_addition.jl") addtests(testrunner, "nodes/test_bifm.jl") From 5bd75a27053cf748a91c7817d81179851918f24c Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Fri, 24 Nov 2023 13:56:29 +0100 Subject: [PATCH 3/9] make format --- src/nodes/delta/delta.jl | 6 +++--- test/nodes/delta/test_delta.jl | 13 ++++--------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/src/nodes/delta/delta.jl b/src/nodes/delta/delta.jl index a5add84aa..d3fc96bf1 100644 --- a/src/nodes/delta/delta.jl +++ b/src/nodes/delta/delta.jl @@ -35,7 +35,7 @@ struct DeltaFnNode{F, P, N, S, L, M} <: AbstractFactorNode proxy::P out::NodeInterface ins::NTuple{N, IndexedNodeInterface} - + statics :: S localmarginals :: L metadata :: M @@ -61,7 +61,7 @@ collect_meta(::Type{<:DeltaFn}, method::AbstractApproximationMethod) = DeltaMeta function nodefunction(factornode::DeltaFnNode) # `DeltaFnNode` `nodefunction` is `δ(y - f(ins...))` return let f = nodefunction(factornode, Val(:out)) - (y, ins...) -> (iszero(y - f(ins...))) ? 1 : 0 + (y, ins...) -> (y - f(ins...) ≈ 0) ? 1 : 0 end end @@ -126,7 +126,7 @@ function __make_delta_fn_node(fn::F, options::FactorNodeCreationOptions, out::Ab setmessagein!(in_var, in_index, messageout(in_interface)) end - foreach(statics) do static + foreach(statics) do static setused!(FixedArguments.value(static)) end diff --git a/test/nodes/delta/test_delta.jl b/test/nodes/delta/test_delta.jl index 2cdf428c8..b9b89837d 100644 --- a/test/nodes/delta/test_delta.jl +++ b/test/nodes/delta/test_delta.jl @@ -3,14 +3,13 @@ module DeltaNodeTest using Test, ReactiveMP, Random @testset "DeltaNode" begin - @testset "Creation with static inputs (simple case) #1" begin import ReactiveMP: nodefunction, FactorNodeCreationOptions foo(x, y, z) = x * y + z out = randomvar(:out) - + x = randomvar(:x) y = datavar(:y, Float64) z = constvar(:z, 3.0) @@ -24,7 +23,6 @@ using Test, ReactiveMP, Random @test nodefunction(node)(foo(xval, 2.0, 3.0), xval) === 1 @test nodefunction(node)(foo(xval, 2.0, 3.0) + 1.0, xval) === 0 end - end @testset "Creation with static inputs (all permutations) #2" begin @@ -37,14 +35,14 @@ using Test, ReactiveMP, Random out = randomvar(:out) opt = FactorNodeCreationOptions(nothing, Linearization(), nothing) - for vals in [ rand(Float64, 3) for _ in 1:10 ], foo in (foo1, foo2, foo3) + for vals in [rand(Float64, 3) for _ in 1:10], foo in (foo1, foo2, foo3) # In this test we attempt to create a lot of possible combinations # of random, data and const inputs to the delta node create_interfaces(i) = (randomvar(:x), datavar(:y, Float64), constvar(:z, vals[i])) for x in create_interfaces(1), y in create_interfaces(2), z in create_interfaces(3) - interfaces = [ x, y, z ] + interfaces = [x, y, z] rpos = findall(i -> i isa RandomVariable, interfaces) node = make_node(foo, opt, out, interfaces...) @@ -60,11 +58,8 @@ using Test, ReactiveMP, Random @test nodefunction(node)(foo(vals...), vals[rpos]...) === 1 @test nodefunction(node)(foo(vals...) + 1, vals[rpos]...) === 0 end - end - end - end -end \ No newline at end of file +end From 54d681309dc9e564b87cdc8c837e6369717bbb46 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Fri, 24 Nov 2023 13:58:27 +0100 Subject: [PATCH 4/9] add more comments --- src/nodes/delta/delta.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/nodes/delta/delta.jl b/src/nodes/delta/delta.jl index d3fc96bf1..04a30a311 100644 --- a/src/nodes/delta/delta.jl +++ b/src/nodes/delta/delta.jl @@ -148,17 +148,22 @@ end # groups: the first group is of type `RandomVariable` and the second group is of type `ConstVariable/DataVariable` __split_static_inputs(ins::Tuple) = __split_static_inputs(Val(1), (), (), ins) -__split_static_inputs(::Val{N}, randoms, statics, ins::Tuple{}) where {N} = (randoms, statics) -__split_static_inputs(::Val{N}, randoms, statics, ins::Tuple) where {N} = __split_static_inputs(Val(N), randoms, statics, first(ins), Base.tail(ins)) +__split_static_inputs(::Val{N}, randoms, statics, remaining::Tuple{}) where {N} = (randoms, statics) +__split_static_inputs(::Val{N}, randoms, statics, remaining::Tuple) where {N} = __split_static_inputs(Val(N), randoms, statics, first(remaining), Base.tail(remaining)) +# If the current input is a random variable, we add it to the `randoms` tuple function __split_static_inputs(::Val{N}, randoms, statics, current::RandomVariable, remaining::Tuple) where {N} return __split_static_inputs(Val(N + 1), (randoms..., current), statics, remaining) end +# If the current input is a const/data variable, we add it to the `statics` tuple with its respective position function __split_static_inputs(::Val{N}, randoms, statics, current::Union{ConstVariable, DataVariable}, remaining::Tuple) where {N} return __split_static_inputs(Val(N + 1), randoms, (statics..., FixedArgument(FixedPosition(N), current)), remaining) end +# This function is used to unpack the latest value of the static variables +# For constvar we just return the value +# For datavar we get the latest value from the data stream __unpack_latest_static(_, constvar::ConstVariable) = getconst(constvar) __unpack_latest_static(_, datavar::DataVariable) = BayesBase.getpointmass(getdata(Rocket.getrecent(messageout(datavar, 1)))) From 514ec83ff142ac2ffba20a41c404980b9de8f1db Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Fri, 24 Nov 2023 13:59:18 +0100 Subject: [PATCH 5/9] 2prev --- src/nodes/delta/delta.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/nodes/delta/delta.jl b/src/nodes/delta/delta.jl index 04a30a311..7d0659822 100644 --- a/src/nodes/delta/delta.jl +++ b/src/nodes/delta/delta.jl @@ -148,7 +148,9 @@ end # groups: the first group is of type `RandomVariable` and the second group is of type `ConstVariable/DataVariable` __split_static_inputs(ins::Tuple) = __split_static_inputs(Val(1), (), (), ins) +# Stop if the `remaining` tuple is empty __split_static_inputs(::Val{N}, randoms, statics, remaining::Tuple{}) where {N} = (randoms, statics) +# Split the `remaining` tuple into head (current) and tail (remaining) __split_static_inputs(::Val{N}, randoms, statics, remaining::Tuple) where {N} = __split_static_inputs(Val(N), randoms, statics, first(remaining), Base.tail(remaining)) # If the current input is a random variable, we add it to the `randoms` tuple From e58c036f094a6f6a1bdea850246417339d05cf43 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Fri, 24 Nov 2023 15:57:42 +0100 Subject: [PATCH 6/9] fix a bug for the backward message --- src/nodes/delta/delta.jl | 6 +++++- src/nodes/delta/layouts/default.jl | 14 ++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/nodes/delta/delta.jl b/src/nodes/delta/delta.jl index 7d0659822..60c47d0d7 100644 --- a/src/nodes/delta/delta.jl +++ b/src/nodes/delta/delta.jl @@ -46,7 +46,7 @@ as_node_symbol(::Type{<:DeltaFn{F}}) where {F} = Symbol(replace(string(nameof(F) functionalform(factornode::DeltaFnNode{F}) where {F} = DeltaFn{F} sdtype(factornode::DeltaFnNode) = Deterministic() interfaces(factornode::DeltaFnNode) = (factornode.out, factornode.ins...) -factorisation(factornode::DeltaFnNode{F, N}) where {F, N} = ntuple(identity, N + 1) +factorisation(factornode::DeltaFnNode{F}) where {F} = ntuple(identity, length(factornode.ins) + 1) localmarginals(factornode::DeltaFnNode) = factornode.localmarginals.marginals localmarginalnames(factornode::DeltaFnNode) = map(name, localmarginals(factornode)) metadata(factornode::DeltaFnNode) = factornode.metadata @@ -141,6 +141,10 @@ function __make_delta_fn_node(fn::F, options::FactorNodeCreationOptions, out::Ab @warn "Delta node does not support the `pipeline` option." end + if !isnothing(getinverse(meta)) && !isempty(statics) + error("The inverse function specification is not supported for the Delta node, which is connected to datavar/constvar edges.") + end + return DeltaFnNode(fn, proxy, out_interface, ins_interface, statics, localmarginals, meta) end diff --git a/src/nodes/delta/layouts/default.jl b/src/nodes/delta/layouts/default.jl index d482bc812..2286d54bf 100644 --- a/src/nodes/delta/layouts/default.jl +++ b/src/nodes/delta/layouts/default.jl @@ -19,13 +19,22 @@ struct DeltaFnDefaultRuleLayout <: AbstractDeltaNodeDependenciesLayout end import FixedArguments -function with_statics(factornode::DeltaFnNode, stream::T) where {T} +function with_statics(factornode::DeltaFnNode, stream) + return with_statics(factornode, factornode.statics, stream) +end + +function with_statics(factornode::DeltaFnNode, statics::Tuple, stream::T) where {T} # We wait for the statics to be available, but ignore their actual values # They are being injected indirectly with the `fix` function upon node creation statics = map(static -> messageout(static, 1), FixedArguments.value.(factornode.statics)) return combineLatest((stream, combineLatest(statics, PushNew()))) |> map(eltype(T), first) end +function with_statics(factornode::DeltaFnNode, statics::Tuple{}, stream::T) where {T} + # There is no need to touch the original stream if there are no statics + return stream +end + # This function declares how to compute `q_out` locally around `DeltaFn` function deltafn_apply_layout(::DeltaFnDefaultRuleLayout, ::Val{:q_out}, factornode::DeltaFnNode, pipeline_stages, scheduler, addons) let out = factornode.out, localmarginal = factornode.localmarginals.marginals[1] @@ -151,7 +160,8 @@ function deltafn_apply_layout(::DeltaFnDefaultKnownInverseRuleLayout, ::Val{:m_o end # This function declares how to compute `m_in` -function deltafn_apply_layout(::DeltaFnDefaultKnownInverseRuleLayout, ::Val{:m_in}, factornode::DeltaFnNode{F, N}, pipeline_stages, scheduler, addons) where {F, N} +function deltafn_apply_layout(::DeltaFnDefaultKnownInverseRuleLayout, ::Val{:m_in}, factornode::DeltaFnNode{F}, pipeline_stages, scheduler, addons) where {F} + N = length(factornode.ins) # For each outbound message from `in_k` edge we need an inbound messages from all OTHER! `in_*` edges and inbound message on `m_out` foreach(enumerate(factornode.ins)) do (index, interface) From d1292423746fa1f6527657234a988e35d097d87c Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Fri, 24 Nov 2023 16:00:04 +0100 Subject: [PATCH 7/9] style: make format --- src/nodes/delta/delta.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/nodes/delta/delta.jl b/src/nodes/delta/delta.jl index 60c47d0d7..953cda87a 100644 --- a/src/nodes/delta/delta.jl +++ b/src/nodes/delta/delta.jl @@ -43,13 +43,13 @@ end as_node_symbol(::Type{<:DeltaFn{F}}) where {F} = Symbol(replace(string(nameof(F)), "#" => "")) -functionalform(factornode::DeltaFnNode{F}) where {F} = DeltaFn{F} -sdtype(factornode::DeltaFnNode) = Deterministic() -interfaces(factornode::DeltaFnNode) = (factornode.out, factornode.ins...) -factorisation(factornode::DeltaFnNode{F}) where {F} = ntuple(identity, length(factornode.ins) + 1) -localmarginals(factornode::DeltaFnNode) = factornode.localmarginals.marginals -localmarginalnames(factornode::DeltaFnNode) = map(name, localmarginals(factornode)) -metadata(factornode::DeltaFnNode) = factornode.metadata +functionalform(factornode::DeltaFnNode{F}) where {F} = DeltaFn{F} +sdtype(factornode::DeltaFnNode) = Deterministic() +interfaces(factornode::DeltaFnNode) = (factornode.out, factornode.ins...) +factorisation(factornode::DeltaFnNode{F}) where {F} = ntuple(identity, length(factornode.ins) + 1) +localmarginals(factornode::DeltaFnNode) = factornode.localmarginals.marginals +localmarginalnames(factornode::DeltaFnNode) = map(name, localmarginals(factornode)) +metadata(factornode::DeltaFnNode) = factornode.metadata collect_meta(::Type{D}, something::Nothing) where {D <: DeltaFn} = error( "Delta node `$(as_node_symbol(D))` requires meta specification with the `where { meta = ... }` in the `@model` macro or with the separate `@meta` specification. See documentation for the `DeltaMeta`." From e5b3d4cb9785c8f929c8e7cfd1a1f01eb237e174 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Fri, 24 Nov 2023 16:02:10 +0100 Subject: [PATCH 8/9] make format --- src/nodes/delta/delta.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nodes/delta/delta.jl b/src/nodes/delta/delta.jl index 953cda87a..fd1cec24f 100644 --- a/src/nodes/delta/delta.jl +++ b/src/nodes/delta/delta.jl @@ -61,7 +61,7 @@ collect_meta(::Type{<:DeltaFn}, method::AbstractApproximationMethod) = DeltaMeta function nodefunction(factornode::DeltaFnNode) # `DeltaFnNode` `nodefunction` is `δ(y - f(ins...))` return let f = nodefunction(factornode, Val(:out)) - (y, ins...) -> (y - f(ins...) ≈ 0) ? 1 : 0 + (y, ins...) -> ((y - f(ins...)) ≈ 0) ? 1 : 0 end end From c16d0be0b9319724440cce8921b71862d5c29a7d Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Sun, 26 Nov 2023 10:12:06 +0100 Subject: [PATCH 9/9] fix call_rule macro for the delta node --- src/nodes/delta/delta.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nodes/delta/delta.jl b/src/nodes/delta/delta.jl index fd1cec24f..0145e0a34 100644 --- a/src/nodes/delta/delta.jl +++ b/src/nodes/delta/delta.jl @@ -88,7 +88,7 @@ call_rule_is_node_required(::Type{<:DeltaFn}) = CallRuleNodeRequired() function call_rule_make_node(::CallRuleNodeRequired, fformtype::Type{<:DeltaFn}, nodetype::F, meta::DeltaMeta) where {F} # This node is not initialized properly, but we do not expect rules to access internal uninitialized fields. # Doing so will most likely throw an error - return DeltaFnNode(nodetype, NodeInterface(:out, Marginalisation()), (), nothing, collect_meta(DeltaFn{F}, meta)) + return DeltaFnNode(nodetype, nodetype, NodeInterface(:out, Marginalisation()), (), (), nothing, collect_meta(DeltaFn{F}, meta)) end function interfaceindex(factornode::DeltaFnNode, iname::Symbol)