Skip to content

Commit

Permalink
Merge pull request #396 from ReactiveBayes/generic-node-check-input-a…
Browse files Browse the repository at this point in the history
…rguments

Add extra checks during the construction of a generic node
  • Loading branch information
bvdmitri authored Apr 30, 2024
2 parents 6f3d660 + a5cf3d3 commit 0c16753
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 4 deletions.
38 changes: 34 additions & 4 deletions src/nodes/nodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ end
# `PredefinedNodeFunctionalForm` are generally the nodes that are defined with the `@node` macro
# The `UndefinedNodeFunctionalForm` nodes can be created as well, but only if the `fform` is a `Function` (see `predefined/delta.jl`)
function factornode(::PredefinedNodeFunctionalForm, fform::F, interfaces::I, factorization) where {F, I}
processed_interfaces = __prepare_interfaces_generic(fform, interfaces)
processed_interfaces = prepare_interfaces_generic(fform, interfaces)
localclusters = FactorNodeLocalClusters(processed_interfaces, collect_factorisation(fform, factorization))
return FactorNode(fform, processed_interfaces, localclusters)
end
Expand All @@ -197,14 +197,44 @@ interfaceindex(factornode::FactorNode, iname::Symbol) =
interfaceindices(factornode::FactorNode, iname::Symbol) = (interfaceindex(factornode, iname),)
interfaceindices(factornode::FactorNode, inames::NTuple{N, Symbol}) where {N} = map(iname -> interfaceindex(factornode, iname), inames)

# Takes a named tuple of abstract variables and converts to a tuple of NodeInterfaces with the same order
function __prepare_interfaces_generic(fform::F, interfaces::AbstractVector) where {F}
function prepare_interfaces_generic(fform::F, interfaces::AbstractVector) where {F}
prepare_interfaces_check_nonempty(fform, interfaces)
prepare_interfaces_check_adjacent_duplicates(fform, interfaces)
prepare_interfaces_check_numarguments(fform, interfaces)
return map(enumerate(interfaces)) do (index, (name, variable))
return NodeInterface(alias_interface(fform, index, name), variable)
end
end

## activate!
function prepare_interfaces_check_nonempty(fform, interfaces)
length(interfaces) > 0 || error(lazy"At least one argument is required for a factor node. Got none for `$(fform)`")
end

function prepare_interfaces_check_adjacent_duplicates(fform, interfaces)
# Here we create an iterator that checks ONLY adjacent interfaces
# The reason here is that we don't want to check all possible combinations of all input interfaces
# because that would require allocating an intermediate storage for `Set`, which would harm the
# performance of nodes creation. The `zip(interfaces, Iterators.drop(interfaces, 1))` creates a generic
# iterator of adjacent interface pairs
foreach(zip(interfaces, Iterators.drop(interfaces, 1))) do (left, right)
lname, _ = left
rname, _ = right
if isequal(lname, rname)
error(
lazy"`$fform` has duplicate entry for interface `$lname`. Did you pass an array (e.g. `x`) instead of an array element (e.g. `x[i]`)? Check your variable indices."
)
end
end
end

function prepare_interfaces_check_numarguments(fform::F, interfaces) where {F}
prepare_interfaces_check_num_inputarguments(fform, inputinterfaces(fform), interfaces)
end

function prepare_interfaces_check_num_inputarguments(fform, inputinterfaces::Val{Input}, interfaces) where {Input}
(length(interfaces) - 1) === length(Input) ||
error(lazy"Expected $(length(Input)) input arguments for `$(fform)`, got $(length(interfaces) - 1): $(join(map(first, Iterators.drop(interfaces, 1)), \", \"))")
end

struct FactorNodeActivationOptions{M, D, P, A, S}
metadata::M
Expand Down
64 changes: 64 additions & 0 deletions test/nodes/nodes_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,67 @@ end
@test occursin(r"DummyNodeForDocumentationStochastic.*Stochastic.*out, x, y \(or yy\)", documentation)
@test occursin(r"DummyNodeForDocumentationDeterministic.*Deterministic.*out, x \(or xx, xxx\), y", documentation)
end

@testitem "Predefined nodes should check the arguments supplied" begin
struct StochasticNodeWithThreeArguments end
struct DeterministicNodeWithFourArguments end

@node StochasticNodeWithThreeArguments Stochastic [out, x, y, z]
@node DeterministicNodeWithFourArguments Deterministic [out, x, y, z, w]

out = randomvar()
x = randomvar()
y = randomvar()
z = randomvar()
w = randomvar()

@test factornode(StochasticNodeWithThreeArguments, [(:out, out), (:x, x), (:y, y), (:z, z)], ((1, 2, 3),)) isa ReactiveMP.FactorNode
@test factornode(DeterministicNodeWithFourArguments, [(:out, out), (:x, x), (:y, y), (:z, z), (:w, w)], ((1, 2, 3, 4),)) isa ReactiveMP.FactorNode

@test_throws r"At least one argument is required for a factor node. Got none for `.*StochasticNodeWithThreeArguments`" factornode(StochasticNodeWithThreeArguments, [], ())
@test_throws r"At least one argument is required for a factor node. Got none for `.*DeterministicNodeWithFourArguments`" factornode(DeterministicNodeWithFourArguments, [], ())
@test_throws r"Expected 3 input arguments for `.*StochasticNodeWithThreeArguments`, got 1: x" factornode(StochasticNodeWithThreeArguments, [(:out, out), (:x, x)], ((1,),))
@test_throws r"Expected 3 input arguments for `.*StochasticNodeWithThreeArguments`, got 2: x, y" factornode(
StochasticNodeWithThreeArguments, [(:out, out), (:x, x), (:y, y)], ((1, 2),)
)
@test_throws r"Expected 3 input arguments for `.*StochasticNodeWithThreeArguments`, got 4: x, y, z, w" factornode(
StochasticNodeWithThreeArguments, [(:out, out), (:x, x), (:y, y), (:z, z), (:w, w)], ((1, 2, 3, 4),)
)
@test_throws r"Expected 4 input arguments for `.*DeterministicNodeWithFourArguments`, got 1: x" factornode(DeterministicNodeWithFourArguments, [(:out, out), (:x, x)], ((1,),))
@test_throws r"Expected 4 input arguments for `.*DeterministicNodeWithFourArguments`, got 2: x, y" factornode(
DeterministicNodeWithFourArguments, [(:out, out), (:x, x), (:y, y)], ((1, 2),)
)
@test_throws r"Expected 4 input arguments for `.*DeterministicNodeWithFourArguments`, got 3: x, y, z" factornode(
DeterministicNodeWithFourArguments, [(:out, out), (:x, x), (:y, y), (:z, z)], ((1, 2, 3),)
)

@test_throws r"`.*StochasticNodeWithThreeArguments` has duplicate entry for interface `w`. Did you pass an array \(e.g. `x`\) instead of an array element \(e\.g\. `x\[i\]`\)\? Check your variable indices\." factornode(
StochasticNodeWithThreeArguments, [(:out, out), (:x, x), (:w, w), (:w, w)], ((1, 2, 3, 4),)
)
@test_throws r"`.*StochasticNodeWithThreeArguments` has duplicate entry for interface `w`. Did you pass an array \(e.g. `x`\) instead of an array element \(e\.g\. `x\[i\]`\)\? Check your variable indices\." factornode(
StochasticNodeWithThreeArguments, [(:out, out), (:x, x), (:y, y), (:z, z), (:w, w), (:w, w)], ((1, 2, 3, 4, 5, 6),)
)
end

@testitem "Generic node construction checks should not allocate" begin
import ReactiveMP: prepare_interfaces_check_adjacent_duplicates, prepare_interfaces_check_nonempty, prepare_interfaces_check_numarguments

struct NodeForCheckDuplicatesTest end
@node NodeForCheckDuplicatesTest Stochastic [out, x, y, z]

out = randomvar()
x = randomvar()
y = randomvar()
z = randomvar()

interfaces = [(:out, out), (:x, x), (:y, y), (:z, z)]
# compile first
function foo(interfaces)
prepare_interfaces_check_nonempty(NodeForCheckDuplicatesTest, interfaces)
prepare_interfaces_check_adjacent_duplicates(NodeForCheckDuplicatesTest, interfaces)
prepare_interfaces_check_numarguments(NodeForCheckDuplicatesTest, interfaces)
end
foo(interfaces)
@test (@allocated(foo(interfaces)) == 0)
@test (@allocations(foo(interfaces)) == 0)
end

0 comments on commit 0c16753

Please sign in to comment.