diff --git a/src/nodes/nodes.jl b/src/nodes/nodes.jl index bb619e199..510180a07 100644 --- a/src/nodes/nodes.jl +++ b/src/nodes/nodes.jl @@ -180,7 +180,7 @@ end # `PredefinedNodeFunctionalForm` are generally the nodes that are defined with the `@node` macro # The `UndefinedNodeFunctionalForm` nodes can be created as well, but only if the `fform` is a `Function` (see `predefined/delta.jl`) function factornode(::PredefinedNodeFunctionalForm, fform::F, interfaces::I, factorization) where {F, I} - processed_interfaces = __prepare_interfaces_generic(fform, interfaces) + processed_interfaces = prepare_interfaces_generic(fform, interfaces) localclusters = FactorNodeLocalClusters(processed_interfaces, collect_factorisation(fform, factorization)) return FactorNode(fform, processed_interfaces, localclusters) end @@ -197,14 +197,44 @@ interfaceindex(factornode::FactorNode, iname::Symbol) = interfaceindices(factornode::FactorNode, iname::Symbol) = (interfaceindex(factornode, iname),) interfaceindices(factornode::FactorNode, inames::NTuple{N, Symbol}) where {N} = map(iname -> interfaceindex(factornode, iname), inames) -# Takes a named tuple of abstract variables and converts to a tuple of NodeInterfaces with the same order -function __prepare_interfaces_generic(fform::F, interfaces::AbstractVector) where {F} +function prepare_interfaces_generic(fform::F, interfaces::AbstractVector) where {F} + prepare_interfaces_check_nonempty(fform, interfaces) + prepare_interfaces_check_adjacent_duplicates(fform, interfaces) + prepare_interfaces_check_numarguments(fform, interfaces) return map(enumerate(interfaces)) do (index, (name, variable)) return NodeInterface(alias_interface(fform, index, name), variable) end end -## activate! +function prepare_interfaces_check_nonempty(fform, interfaces) + length(interfaces) > 0 || error(lazy"At least one argument is required for a factor node. Got none for `$(fform)`") +end + +function prepare_interfaces_check_adjacent_duplicates(fform, interfaces) + # Here we create an iterator that checks ONLY adjacent interfaces + # The reason here is that we don't want to check all possible combinations of all input interfaces + # because that would require allocating an intermediate storage for `Set`, which would harm the + # performance of nodes creation. The `zip(interfaces, Iterators.drop(interfaces, 1))` creates a generic + # iterator of adjacent interface pairs + foreach(zip(interfaces, Iterators.drop(interfaces, 1))) do (left, right) + lname, _ = left + rname, _ = right + if isequal(lname, rname) + error( + lazy"`$fform` has duplicate entry for interface `$lname`. Did you pass an array (e.g. `x`) instead of an array element (e.g. `x[i]`)? Check your variable indices." + ) + end + end +end + +function prepare_interfaces_check_numarguments(fform::F, interfaces) where {F} + prepare_interfaces_check_num_inputarguments(fform, inputinterfaces(fform), interfaces) +end + +function prepare_interfaces_check_num_inputarguments(fform, inputinterfaces::Val{Input}, interfaces) where {Input} + (length(interfaces) - 1) === length(Input) || + error(lazy"Expected $(length(Input)) input arguments for `$(fform)`, got $(length(interfaces) - 1): $(join(map(first, Iterators.drop(interfaces, 1)), \", \"))") +end struct FactorNodeActivationOptions{M, D, P, A, S} metadata::M diff --git a/test/nodes/nodes_tests.jl b/test/nodes/nodes_tests.jl index bdc7d2b0a..cc35ac205 100644 --- a/test/nodes/nodes_tests.jl +++ b/test/nodes/nodes_tests.jl @@ -156,3 +156,67 @@ end @test occursin(r"DummyNodeForDocumentationStochastic.*Stochastic.*out, x, y \(or yy\)", documentation) @test occursin(r"DummyNodeForDocumentationDeterministic.*Deterministic.*out, x \(or xx, xxx\), y", documentation) end + +@testitem "Predefined nodes should check the arguments supplied" begin + struct StochasticNodeWithThreeArguments end + struct DeterministicNodeWithFourArguments end + + @node StochasticNodeWithThreeArguments Stochastic [out, x, y, z] + @node DeterministicNodeWithFourArguments Deterministic [out, x, y, z, w] + + out = randomvar() + x = randomvar() + y = randomvar() + z = randomvar() + w = randomvar() + + @test factornode(StochasticNodeWithThreeArguments, [(:out, out), (:x, x), (:y, y), (:z, z)], ((1, 2, 3),)) isa ReactiveMP.FactorNode + @test factornode(DeterministicNodeWithFourArguments, [(:out, out), (:x, x), (:y, y), (:z, z), (:w, w)], ((1, 2, 3, 4),)) isa ReactiveMP.FactorNode + + @test_throws r"At least one argument is required for a factor node. Got none for `.*StochasticNodeWithThreeArguments`" factornode(StochasticNodeWithThreeArguments, [], ()) + @test_throws r"At least one argument is required for a factor node. Got none for `.*DeterministicNodeWithFourArguments`" factornode(DeterministicNodeWithFourArguments, [], ()) + @test_throws r"Expected 3 input arguments for `.*StochasticNodeWithThreeArguments`, got 1: x" factornode(StochasticNodeWithThreeArguments, [(:out, out), (:x, x)], ((1,),)) + @test_throws r"Expected 3 input arguments for `.*StochasticNodeWithThreeArguments`, got 2: x, y" factornode( + StochasticNodeWithThreeArguments, [(:out, out), (:x, x), (:y, y)], ((1, 2),) + ) + @test_throws r"Expected 3 input arguments for `.*StochasticNodeWithThreeArguments`, got 4: x, y, z, w" factornode( + StochasticNodeWithThreeArguments, [(:out, out), (:x, x), (:y, y), (:z, z), (:w, w)], ((1, 2, 3, 4),) + ) + @test_throws r"Expected 4 input arguments for `.*DeterministicNodeWithFourArguments`, got 1: x" factornode(DeterministicNodeWithFourArguments, [(:out, out), (:x, x)], ((1,),)) + @test_throws r"Expected 4 input arguments for `.*DeterministicNodeWithFourArguments`, got 2: x, y" factornode( + DeterministicNodeWithFourArguments, [(:out, out), (:x, x), (:y, y)], ((1, 2),) + ) + @test_throws r"Expected 4 input arguments for `.*DeterministicNodeWithFourArguments`, got 3: x, y, z" factornode( + DeterministicNodeWithFourArguments, [(:out, out), (:x, x), (:y, y), (:z, z)], ((1, 2, 3),) + ) + + @test_throws r"`.*StochasticNodeWithThreeArguments` has duplicate entry for interface `w`. Did you pass an array \(e.g. `x`\) instead of an array element \(e\.g\. `x\[i\]`\)\? Check your variable indices\." factornode( + StochasticNodeWithThreeArguments, [(:out, out), (:x, x), (:w, w), (:w, w)], ((1, 2, 3, 4),) + ) + @test_throws r"`.*StochasticNodeWithThreeArguments` has duplicate entry for interface `w`. Did you pass an array \(e.g. `x`\) instead of an array element \(e\.g\. `x\[i\]`\)\? Check your variable indices\." factornode( + StochasticNodeWithThreeArguments, [(:out, out), (:x, x), (:y, y), (:z, z), (:w, w), (:w, w)], ((1, 2, 3, 4, 5, 6),) + ) +end + +@testitem "Generic node construction checks should not allocate" begin + import ReactiveMP: prepare_interfaces_check_adjacent_duplicates, prepare_interfaces_check_nonempty, prepare_interfaces_check_numarguments + + struct NodeForCheckDuplicatesTest end + @node NodeForCheckDuplicatesTest Stochastic [out, x, y, z] + + out = randomvar() + x = randomvar() + y = randomvar() + z = randomvar() + + interfaces = [(:out, out), (:x, x), (:y, y), (:z, z)] + # compile first + function foo(interfaces) + prepare_interfaces_check_nonempty(NodeForCheckDuplicatesTest, interfaces) + prepare_interfaces_check_adjacent_duplicates(NodeForCheckDuplicatesTest, interfaces) + prepare_interfaces_check_numarguments(NodeForCheckDuplicatesTest, interfaces) + end + foo(interfaces) + @test (@allocated(foo(interfaces)) == 0) + @test (@allocations(foo(interfaces)) == 0) +end