Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify variables structures for predictions functionality #248

Merged
merged 76 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
da44b05
Add MAR node
albertpod Oct 17, 2022
ef8e17e
Add rules prototypes
albertpod Oct 20, 2022
35aba35
Add rules for MAR
albertpod Oct 28, 2022
3462700
Update MAR
albertpod Oct 31, 2022
b160302
Merge branch 'master' into mar
albertpod Nov 13, 2022
24cfeaf
Merge branch 'master' into mar
albertpod Nov 27, 2022
91179d1
Update rules
albertpod Nov 28, 2022
6795aff
Update rules
albertpod Dec 30, 2022
bd7be22
WIP: Update mask MAR function
albertpod Jan 2, 2023
86a24e4
Bug fix
albertpod Jan 2, 2023
dcb5bd5
Update rules
albertpod Jan 3, 2023
adc58ad
Update rules
albertpod Jan 3, 2023
ee8b471
project: add BlockArrays dependency
bvdmitri Jan 4, 2023
a41b0d9
Merge branch 'mar' of github.com:biaslab/ReactiveMP.jl into mar
bvdmitri Jan 4, 2023
4d3e221
fix constructor
bvdmitri Jan 4, 2023
09140d1
Update rules
albertpod Jan 4, 2023
9e15486
Merge branch 'mar' of https://github.com/biaslab/ReactiveMP.jl into mar
albertpod Jan 4, 2023
78c5216
Update FE
albertpod Jan 4, 2023
a723fcd
Merge branch 'master' into mar
ismailsenoz Jan 4, 2023
67d47cc
Update rule
albertpod Jan 5, 2023
96731d0
Merge branch 'mar' of https://github.com/biaslab/ReactiveMP.jl into mar
albertpod Jan 5, 2023
f2fd4d5
Update MAR rules
albertpod Jan 6, 2023
5eab7be
Update rules
albertpod Jan 6, 2023
395ea37
WIP: Update marginals & lambda
albertpod Jan 6, 2023
2725813
Fix bug
albertpod Jan 7, 2023
cfc4243
Fix backward rule
albertpod Jan 9, 2023
7a91826
Update rules
albertpod Jan 10, 2023
d5248e0
Fix FE
albertpod Jan 10, 2023
01f59f1
Clean up
albertpod Jan 10, 2023
a839f3d
Update rules
albertpod Jan 11, 2023
b52f77b
Update MF rules
albertpod Jan 11, 2023
87399c2
Modify variables structures for predictions functionality
albertpod Jan 24, 2023
f173652
Merge branch 'master' into dev-predict
albertpod Jan 24, 2023
ae0e770
Make format
albertpod Jan 24, 2023
e2d56e8
Merge branch 'master' into dev-predict
albertpod Jan 30, 2023
696f8ea
WIP: Change data
albertpod Jan 30, 2023
22ce46e
feat: add allows_missings function & tests
bvdmitri Jan 30, 2023
54074d3
Make format
albertpod Feb 1, 2023
fe57269
Merge branch 'master' into mar
albertpod Feb 1, 2023
d9564ed
improve factorisation logic for prediction variables
bvdmitri Feb 1, 2023
33d7ebd
Merge branch 'dev-predict' into mar
albertpod Feb 1, 2023
e936392
fix: update warning for factorisation check
bvdmitri Feb 1, 2023
2ae7b4d
Merge branch 'dev-predict' into mar
albertpod Feb 2, 2023
07c5bdc
Update mapping for marginal
albertpod Feb 6, 2023
b99d989
Merge branch 'master' into dev-predict
albertpod Feb 6, 2023
9179163
Merge branch 'dev-predict' into mar
albertpod Feb 6, 2023
37d0a72
Make format
albertpod Feb 7, 2023
2bb563f
Delete WIPs
albertpod Feb 7, 2023
65cdae8
Make format
albertpod Feb 7, 2023
395cc62
Merge branch 'dev-predict' into mar
albertpod Feb 7, 2023
2a2aa82
Merge branch 'master' into mar
albertpod Feb 13, 2023
cbbc8f1
Merge branch 'master' into dev-predict
albertpod Feb 22, 2023
dd05659
Merge branch 'dev-predict' into mar
albertpod Feb 22, 2023
19547be
Merge master into dev-predict
albertpod Mar 6, 2023
a490f8a
fix tests
bvdmitri Mar 7, 2023
eaae1d2
Update rules
albertpod Mar 19, 2023
dc836fe
Fix MAR rules
albertpod Mar 19, 2023
644a8e8
Merge branch 'dev-predict' into mar
albertpod Mar 19, 2023
a3fcc37
Decrease allocs
albertpod Mar 19, 2023
c22d44a
Merge branch 'master' into mar
albertpod Mar 19, 2023
db06986
Optmize functions
albertpod Mar 21, 2023
3b4bd2b
Merge branch 'master' into dev-predict
albertpod Mar 21, 2023
ac9d8c0
Merge branch 'master' into mar
albertpod Mar 30, 2023
b89e419
Remove diffs
albertpod Mar 30, 2023
3ff75e8
Merge branch 'master' into mar
albertpod May 29, 2023
a6ce995
Merge branch 'master' into dev-predict
albertpod Jun 19, 2023
a84991e
Merge branch 'master' into mar
albertpod Jul 23, 2023
b04fb89
Merge branch 'mar' into dev-predict
albertpod Jul 23, 2023
59a79b0
Merge branch 'master' into dev-predict
albertpod Sep 8, 2023
e0edea1
Remove MV autoregressive node
albertpod Sep 8, 2023
0643188
Remove mv autoregressive from ReactiveMP.jl
albertpod Sep 12, 2023
c659769
Merge branch 'master' into dev-predict
albertpod Sep 12, 2023
a248d5b
Remove not needed exports
albertpod Sep 12, 2023
5ddad79
Remove BlockArrays
albertpod Sep 12, 2023
0c94a12
Update src/variables/data.jl
albertpod Sep 18, 2023
ad7b165
fix warning for predicted datavars
bvdmitri Sep 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,6 @@ uuid = "a194aa59-28ba-4574-a09c-4a745416d6e3"
authors = ["Dmitry Bagaev <d.v.bagaev@tue.nl>", "Albert Podusenko <a.podusenko@tue.nl>", "Bart van Erp <b.v.erp@tue.nl>", "Ismail Senoz <i.senoz@tue.nl>"]
version = "3.9.3"

[weakdeps]
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"

[extensions]
ReactiveMPOptimisersExt = "Optimisers"
ReactiveMPZygoteExt = "Zygote"
ReactiveMPRequiresExt = "Requires"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -38,6 +28,16 @@ TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
Unrolled = "9602ed7d-8fef-5bc8-8597-8f21381861e8"

[weakdeps]
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
ReactiveMPOptimisersExt = "Optimisers"
ReactiveMPRequiresExt = "Requires"
ReactiveMPZygoteExt = "Zygote"

[compat]
DataStructures = "0.17, 0.18"
Distributions = "0.24, 0.25"
Expand Down Expand Up @@ -73,9 +73,9 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand Down
5 changes: 4 additions & 1 deletion src/constraints/specifications/constraints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,10 @@
foreach(constraints.factorisation) do spec
specnames = getnames(spec)
foreach(specnames) do specname
if warn && (hasdatavar(variables, specname) || hasconstvar(variables, specname))
if hasdatavar(variables, specname) && allows_missings(variables[specname])

Check warning on line 107 in src/constraints/specifications/constraints.jl

View check run for this annotation

Codecov / codecov/patch

src/constraints/specifications/constraints.jl#L107

Added line #L107 was not covered by tests
# skip, because it is fine to have a datavar in the factorization constraint, which allows missings
nothing
elseif warn && (hasdatavar(variables, specname) || hasconstvar(variables, specname))

Check warning on line 110 in src/constraints/specifications/constraints.jl

View check run for this annotation

Codecov / codecov/patch

src/constraints/specifications/constraints.jl#L109-L110

Added lines #L109 - L110 were not covered by tests
@warn "Constraints specification has factorisation constraint for `q($(join(specnames, ", ")))`, but `$(specname)` is not a random variable. Data variables and constants in the model are forced to be factorized by default such that `q($(join(specnames, ", "))) = q($(specname))q(...)` . Use `warn = false` option during constraints specification to suppress this warning."
elseif warn && !hasrandomvar(variables, specname)
@warn "Constraints specification has factorisation constraint for `q($(join(specnames, ", ")))`, but variables collection has no random variable named `$(specname)`. Use `warn = false` option during constraints specification to suppress this warning."
Expand Down
68 changes: 54 additions & 14 deletions src/constraints/specifications/factorisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,19 +187,56 @@
# Preoptimised dispatch rule for unspecified constraints and a deterministic node with any number of inputs
resolve_factorisation(::UnspecifiedConstraints, ::Deterministic, allvariables, fform, variables) = FullFactorisation()

# Preoptimised dispatch rules for unspecified constraints and a stochastic node with 2 inputs
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2}) where {V1 <: RandomVariable, V2 <: RandomVariable} = ((1, 2),)
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2}) where {V1 <: Union{<:ConstVariable, <:DataVariable}, V2 <: RandomVariable} = ((1,), (2,))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2}) where {V1 <: RandomVariable, V2 <: Union{<:ConstVariable, <:DataVariable}} = ((1,), (2,))

# Preoptimised dispatch rules for unspecified constraints and a stochastic node with 3 inputs
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: RandomVariable, V3 <: RandomVariable} = ((1, 2, 3),)
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: Union{<:ConstVariable, <:DataVariable}, V2 <: RandomVariable, V3 <: RandomVariable} = ((1,), (2, 3))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: Union{<:ConstVariable, <:DataVariable}, V3 <: RandomVariable} = ((1, 3), (2,))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: RandomVariable, V3 <: Union{<:ConstVariable, <:DataVariable}} = ((1, 2), (3,))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: Union{<:ConstVariable, <:DataVariable}, V3 <: Union{<:ConstVariable, <:DataVariable}} = ((1,), (2,), (3,))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: Union{<:ConstVariable, <:DataVariable}, V2 <: RandomVariable, V3 <: Union{<:ConstVariable, <:DataVariable}} = ((1,), (2,), (3,))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: Union{<:ConstVariable, <:DataVariable}, V2 <: Union{<:ConstVariable, <:DataVariable}, V3 <: RandomVariable} = ((1,), (2,), (3,))
# Preoptimised dispatch rules for unspecified constraints and a stochastic node with 2 inputs, random variable & constant variable
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2}) where {V1 <: RandomVariable, V2 <: RandomVariable} = ((1, 2),)
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2}) where {V1 <: ConstVariable, V2 <: RandomVariable} = ((1,), (2,))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2}) where {V1 <: RandomVariable, V2 <: ConstVariable} = ((1,), (2,))

Check warning on line 193 in src/constraints/specifications/factorisation.jl

View check run for this annotation

Codecov / codecov/patch

src/constraints/specifications/factorisation.jl#L191-L193

Added lines #L191 - L193 were not covered by tests

# Preoptimised dispatch rules for unspecified constraints and a stochastic node with 2 inputs, random variable & data variable
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2}) where {V1 <: DataVariable, V2 <: RandomVariable} =

Check warning on line 196 in src/constraints/specifications/factorisation.jl

View check run for this annotation

Codecov / codecov/patch

src/constraints/specifications/factorisation.jl#L196

Added line #L196 was not covered by tests
allows_missings(variables[1]) ? ((1, 2),) : ((1,), (2,))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2}) where {V1 <: RandomVariable, V2 <: DataVariable} =

Check warning on line 198 in src/constraints/specifications/factorisation.jl

View check run for this annotation

Codecov / codecov/patch

src/constraints/specifications/factorisation.jl#L198

Added line #L198 was not covered by tests
allows_missings(variables[2]) ? ((1, 2),) : ((1,), (2,))

# Preoptimised dispatch rules for unspecified constraints and a stochastic node with 3 inputs, random variable & constant variables
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: RandomVariable, V3 <: RandomVariable} = ((1, 2, 3),)
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: ConstVariable, V2 <: RandomVariable, V3 <: RandomVariable} = ((1,), (2, 3))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: ConstVariable, V3 <: RandomVariable} = ((1, 3), (2,))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: RandomVariable, V3 <: ConstVariable} = ((1, 2), (3,))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: ConstVariable, V3 <: ConstVariable} = ((1,), (2,), (3,))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: ConstVariable, V2 <: RandomVariable, V3 <: ConstVariable} = ((1,), (2,), (3,))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: ConstVariable, V2 <: ConstVariable, V3 <: RandomVariable} = ((1,), (2,), (3,))

Check warning on line 208 in src/constraints/specifications/factorisation.jl

View check run for this annotation

Codecov / codecov/patch

src/constraints/specifications/factorisation.jl#L202-L208

Added lines #L202 - L208 were not covered by tests

# Preoptimised dispatch rules for unspecified constraints and a stochastic node with 3 inputs, random variable & data variable
resolve_factorisation(
::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3}
) where {V1 <: DataVariable, V2 <: RandomVariable, V3 <: RandomVariable} = allows_missings(variables[1]) ? ((1, 2, 3),) : ((1,), (2, 3))

Check warning on line 213 in src/constraints/specifications/factorisation.jl

View check run for this annotation

Codecov / codecov/patch

src/constraints/specifications/factorisation.jl#L213

Added line #L213 was not covered by tests
resolve_factorisation(
::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3}
) where {V1 <: RandomVariable, V2 <: DataVariable, V3 <: RandomVariable} = allows_missings(variables[2]) ? ((1, 2, 3),) : ((1, 3), (2,))

Check warning on line 216 in src/constraints/specifications/factorisation.jl

View check run for this annotation

Codecov / codecov/patch

src/constraints/specifications/factorisation.jl#L216

Added line #L216 was not covered by tests
resolve_factorisation(
::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3}
) where {V1 <: RandomVariable, V2 <: RandomVariable, V3 <: DataVariable} = allows_missings(variables[3]) ? ((1, 2, 3),) : ((1, 2), (3,))

Check warning on line 219 in src/constraints/specifications/factorisation.jl

View check run for this annotation

Codecov / codecov/patch

src/constraints/specifications/factorisation.jl#L219

Added line #L219 was not covered by tests

# Preoptimised dispatch rules for unspecified constraints and a stochastic node with 3 inputs, random variable & data variable & const variable
resolve_factorisation(
::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3}
) where {V1 <: DataVariable, V2 <: ConstVariable, V3 <: RandomVariable} = allows_missings(variables[1]) ? ((1, 3), (2,)) : ((1,), (2,), (3,))

Check warning on line 224 in src/constraints/specifications/factorisation.jl

View check run for this annotation

Codecov / codecov/patch

src/constraints/specifications/factorisation.jl#L224

Added line #L224 was not covered by tests
resolve_factorisation(
::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3}
) where {V1 <: DataVariable, V2 <: RandomVariable, V3 <: ConstVariable} = allows_missings(variables[1]) ? ((1, 2), (3,)) : ((1,), (2,), (3,))

Check warning on line 227 in src/constraints/specifications/factorisation.jl

View check run for this annotation

Codecov / codecov/patch

src/constraints/specifications/factorisation.jl#L227

Added line #L227 was not covered by tests
resolve_factorisation(
::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3}
) where {V1 <: ConstVariable, V2 <: DataVariable, V3 <: RandomVariable} = allows_missings(variables[2]) ? ((1,), (2, 3)) : ((1,), (3,), (2,))

Check warning on line 230 in src/constraints/specifications/factorisation.jl

View check run for this annotation

Codecov / codecov/patch

src/constraints/specifications/factorisation.jl#L230

Added line #L230 was not covered by tests
resolve_factorisation(
::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3}
) where {V1 <: RandomVariable, V2 <: DataVariable, V3 <: ConstVariable} = allows_missings(variables[2]) ? ((1, 2), (3,)) : ((1,), (2,), (3,))

Check warning on line 233 in src/constraints/specifications/factorisation.jl

View check run for this annotation

Codecov / codecov/patch

src/constraints/specifications/factorisation.jl#L233

Added line #L233 was not covered by tests
resolve_factorisation(
::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3}
) where {V1 <: ConstVariable, V2 <: RandomVariable, V3 <: DataVariable} = allows_missings(variables[3]) ? ((1,), (2, 3)) : ((1,), (2,), (3,))

Check warning on line 236 in src/constraints/specifications/factorisation.jl

View check run for this annotation

Codecov / codecov/patch

src/constraints/specifications/factorisation.jl#L236

Added line #L236 was not covered by tests
resolve_factorisation(
::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3}
) where {V1 <: RandomVariable, V2 <: ConstVariable, V3 <: DataVariable} = allows_missings(variables[3]) ? ((1, 3), (2,)) : ((1,), (2,), (3,))

Check warning on line 239 in src/constraints/specifications/factorisation.jl

View check run for this annotation

Codecov / codecov/patch

src/constraints/specifications/factorisation.jl#L239

Added line #L239 was not covered by tests

"""
resolve_factorisation(constraints, allvariables, fform, variables)
Expand Down Expand Up @@ -419,8 +456,11 @@
index::Int = 1
shift::Int = 0
for varref in var_refs
if israndom(varref[3])
if israndom(varref[3]) || (isdata(varref[3]) && allows_missings(varref[3]))

Check warning on line 459 in src/constraints/specifications/factorisation.jl

View check run for this annotation

Codecov / codecov/patch

src/constraints/specifications/factorisation.jl#L459

Added line #L459 was not covered by tests
# We process everything as usual if varref is a random variable
# or if the variable is data variable and it allows missing
# We probably should change the logic from "allows missings" to "used as prediction"
# For now we assume that if data variable allows missing input it is indeed "used as prediction"
__process_factorisation_entry!(varref[1], varref[2], shift)
else
# We filter out varref from all clusters if it is not random
Expand Down
8 changes: 7 additions & 1 deletion src/marginal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,13 @@
# Marginal is initial if it is not clamped and all of the inputs are either clamped or initial
is_marginal_initial = !is_marginal_clamped && (__check_all(is_clamped_or_initial, messages) && __check_all(is_clamped_or_initial, marginals))

marginal = marginalrule(marginal_mapping_fform(mapping), mapping.vtag, mapping.msgs_names, messages, mapping.marginals_names, marginals, mapping.meta, mapping.factornode)
marginal = if !isnothing(messages) && any(ismissing, TupleTools.flatten(getdata.(messages)))
missing
elseif !isnothing(marginals) && any(ismissing, TupleTools.flatten(getdata.(marginals)))
missing

Check warning on line 201 in src/marginal.jl

View check run for this annotation

Codecov / codecov/patch

src/marginal.jl#L198-L201

Added lines #L198 - L201 were not covered by tests
else
marginalrule(marginal_mapping_fform(mapping), mapping.vtag, mapping.msgs_names, messages, mapping.marginals_names, marginals, mapping.meta, mapping.factornode)

Check warning on line 203 in src/marginal.jl

View check run for this annotation

Codecov / codecov/patch

src/marginal.jl#L203

Added line #L203 was not covered by tests
end

return Marginal(marginal, is_marginal_clamped, is_marginal_initial, nothing)
end
Expand Down
30 changes: 18 additions & 12 deletions src/message.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,18 +324,24 @@
# Message is initial if it is not clamped and all of the inputs are either clamped or initial
is_message_initial = !is_message_clamped && (__check_all(is_clamped_or_initial, messages) && __check_all(is_clamped_or_initial, marginals))

result, addons = rule(
message_mapping_fform(mapping),
mapping.vtag,
mapping.vconstraint,
mapping.msgs_names,
messages,
mapping.marginals_names,
marginals,
mapping.meta,
mapping.addons,
mapping.factornode
)
result, addons = if !isnothing(messages) && any(ismissing, TupleTools.flatten(getdata.(messages)))
missing, mapping.addons
elseif !isnothing(marginals) && any(ismissing, TupleTools.flatten(getdata.(marginals)))
missing, mapping.addons

Check warning on line 330 in src/message.jl

View check run for this annotation

Codecov / codecov/patch

src/message.jl#L327-L330

Added lines #L327 - L330 were not covered by tests
else
rule(

Check warning on line 332 in src/message.jl

View check run for this annotation

Codecov / codecov/patch

src/message.jl#L332

Added line #L332 was not covered by tests
message_mapping_fform(mapping),
mapping.vtag,
mapping.vconstraint,
mapping.msgs_names,
messages,
mapping.marginals_names,
marginals,
mapping.meta,
mapping.addons,
mapping.factornode
)
end

# Inject extra addons after the rule has been executed
addons = message_mapping_addons(mapping, getdata(messages), getdata(marginals), result, addons)
Expand Down
2 changes: 1 addition & 1 deletion src/node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,7 @@ macro node(fformtype, sdtype, interfaces_list)
missingclustererr = "Cannot find the cluster for the variable connected to the `$(name)` interface around the `$fformtype` node."
quote
# If a variable `$name` is a constvar or a datavar
if ReactiveMP.isconst($(name)) || ReactiveMP.isdata($(name))
if ReactiveMP.isconst($(name)) || (ReactiveMP.isdata($(name)) && !ReactiveMP.allows_missings($(name)))
local __factorisation = ReactiveMP.factorisation(node)
# Find the factorization cluster associated with the constvar `$name`
local __index = ReactiveMP.interface_get_index(Val{$(QuoteNode(fbottomtype))}, Val{$(QuoteNode(name))})
Expand Down
41 changes: 36 additions & 5 deletions src/variables/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
mutable struct DataVariable{D, S} <: AbstractVariable
name :: Symbol
collection_type :: AbstractVariableCollectionType
prediction :: MarginalObservable
input_messages :: Vector{MessageObservable{AbstractMessage}}
messageout :: S
nconnected :: Int
isproxy :: Bool
Expand Down Expand Up @@ -74,12 +76,16 @@
datavar(name::Symbol, ::Type{D}, dims::Vararg{Int}) where {D} = datavar(DataVariableCreationOptions(D), name, D, dims)

datavar(options::DataVariableCreationOptions{S}, name::Symbol, ::Type{D}, collection_type::AbstractVariableCollectionType = VariableIndividual()) where {S, D} =
DataVariable{D, S}(name, collection_type, options.subject, 0, options.isproxy, options.isused)
DataVariable{D, S}(name, collection_type, MarginalObservable(), Vector{MessageObservable{AbstractMessage}}(), options.subject, 0, options.isproxy, options.isused)

function datavar(options::DataVariableCreationOptions, name::Symbol, ::Type{D}, length::Int) where {D}
return map(i -> datavar(similar(options), name, D, VariableVector(i)), 1:length)
end

function datavar(options::DataVariableCreationOptions, name::Symbol, ::Type{D}, dim1::Int, extra_dims::Vararg{Int}) where {D}
return datavar(options, name, D, (dim1, extra_dims...))

Check warning on line 86 in src/variables/data.jl

View check run for this annotation

Codecov / codecov/patch

src/variables/data.jl#L85-L86

Added lines #L85 - L86 were not covered by tests
end

function datavar(options::DataVariableCreationOptions, name::Symbol, ::Type{D}, dims::Tuple) where {D}
indices = CartesianIndices(dims)
size = axes(indices)
Expand All @@ -106,11 +112,17 @@
isconst(::DataVariable) = false
isconst(::AbstractArray{<:DataVariable}) = false

allows_missings(datavar::DataVariable) = allows_missings(datavar, eltype(datavar.messageout))

allows_missings(datavars::AbstractArray{<:DataVariable}) = all(allows_missings, datavars)

Check warning on line 117 in src/variables/data.jl

View check run for this annotation

Codecov / codecov/patch

src/variables/data.jl#L117

Added line #L117 was not covered by tests
allows_missings(datavar::DataVariable, ::Type{Message{D}}) where {D} = false
allows_missings(datavar::DataVariable, ::Type{Union{Message{Missing}, Message{D}}} where {D}) = true

function Base.getindex(datavar::DataVariable, i...)
error("Variable $(indexed_name(datavar)) has been indexed with `[$(join(i, ','))]`. Direct indexing of `data` variables is not allowed.")
end

getlastindex(::DataVariable) = 1
getlastindex(datavar::DataVariable) = degree(datavar) + 1

messageout(datavar::DataVariable, ::Int) = datavar.messageout
messagein(datavar::DataVariable, ::Int) = error("It is not possible to get a reference for inbound message for datavar")
Expand Down Expand Up @@ -163,8 +175,27 @@

setanonymous!(::DataVariable, ::Bool) = nothing

function setmessagein!(datavar::DataVariable, ::Int, messagein)
datavar.nconnected += 1
datavar.isused = true
function setmessagein!(datavar::DataVariable, index::Int, messagein)
if index === (degree(datavar) + 1)
push!(datavar.input_messages, messagein)
datavar.nconnected += 1
datavar.isused = true
else
error(
"Inconsistent state in setmessagein! function for data variable $(datavar). `index` should be equal to `degree(datavar) + 1 = $(degree(datavar) + 1)`, $(index) is given instead"
)
end
return nothing
end

marginal_prod_fn(datavar::DataVariable) = marginal_prod_fn(FoldLeftProdStrategy(), ProdAnalytical(), UnspecifiedFormConstraint(), FormConstraintCheckLast())

Check warning on line 191 in src/variables/data.jl

View check run for this annotation

Codecov / codecov/patch

src/variables/data.jl#L191

Added line #L191 was not covered by tests

_getprediction(datavar::DataVariable) = datavar.prediction
_setprediction!(datavar::DataVariable, observable) = connect!(_getprediction(datavar), observable)
_makeprediction(datavar::DataVariable) = collectLatest(AbstractMessage, Marginal, datavar.input_messages, marginal_prod_fn(datavar))

Check warning on line 195 in src/variables/data.jl

View check run for this annotation

Codecov / codecov/patch

src/variables/data.jl#L193-L195

Added lines #L193 - L195 were not covered by tests

function activate!(datavar::DataVariable, options)
_setprediction!(datavar, _makeprediction(datavar))

Check warning on line 198 in src/variables/data.jl

View check run for this annotation

Codecov / codecov/patch

src/variables/data.jl#L197-L198

Added lines #L197 - L198 were not covered by tests

return nothing
end
5 changes: 4 additions & 1 deletion src/variables/variable.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
export AbstractVariable, degree
export is_clamped, is_marginalisation, is_moment_matching
export FoldLeftProdStrategy, FoldRightProdStrategy, CustomProdStrategy
export getmarginal, getmarginals, setmarginal!, setmarginals!, name, as_variable
export getprediction, getpredictions, getmarginal, getmarginals, setmarginal!, setmarginals!, name, as_variable
export setmessage!, setmessages!

using Rocket
Expand Down Expand Up @@ -80,6 +80,9 @@
# Helper functions
# Getters

getprediction(variable::AbstractVariable) = _getprediction(variable)
getpredictions(variables::AbstractArray{<:AbstractVariable}) = collectLatest(map(v -> getprediction(v), variables))

Check warning on line 84 in src/variables/variable.jl

View check run for this annotation

Codecov / codecov/patch

src/variables/variable.jl#L83-L84

Added lines #L83 - L84 were not covered by tests

getmarginal(variable::AbstractVariable) = getmarginal(variable, SkipInitial())
getmarginal(variable::AbstractVariable, skip_strategy::MarginalSkipStrategy) = apply_skip_filter(_getmarginal(variable), skip_strategy)

Expand Down
Loading
Loading