From 4e7e12222a5cd834b813a0e2b63be24aa0bd14cf Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Fri, 19 Apr 2024 17:31:41 +0200 Subject: [PATCH 1/4] Redirect both message and marginals for linked data variables --- src/message.jl | 7 +- src/variables/data.jl | 32 ++++-- test/nodes/predefined/delta/delta_tests.jl | 26 ++++- test/testutilities.jl | 5 +- test/variables/data_tests.jl | 127 +++++++++++++++++++++ 5 files changed, 178 insertions(+), 19 deletions(-) diff --git a/src/message.jl b/src/message.jl index 2e3787be8..4bb495daf 100644 --- a/src/message.jl +++ b/src/message.jl @@ -241,12 +241,13 @@ dropproxytype(::Type{<:Message{T}}) where {T} = T ## Message observable -struct MessageObservable{M <: AbstractMessage} <: Subscribable{M} - subject :: Rocket.RecentSubjectInstance{M, Subject{M, AsapScheduler, AsapScheduler}} +struct MessageObservable{M <: AbstractMessage, S <: AbstractSubject} <: Subscribable{M} + subject :: S stream :: LazyObservable{M} end -MessageObservable(::Type{M} = AbstractMessage) where {M} = MessageObservable{M}(RecentSubject(M), lazy(M)) +MessageObservable(::Type{M} = AbstractMessage) where {M} = MessageObservable(RecentSubject(M), lazy(M)) +MessageObservable(subject::AbstractSubject{M}) where {M} = MessageObservable(subject, lazy(M)) Rocket.getrecent(observable::MessageObservable) = Rocket.getrecent(observable.subject) diff --git a/src/variables/data.jl b/src/variables/data.jl index 9d1d93393..d43676c22 100644 --- a/src/variables/data.jl +++ b/src/variables/data.jl @@ -1,17 +1,19 @@ export datavar, DataVariable, update!, DataVariableActivationOptions mutable struct DataVariable{M, P} <: AbstractVariable + datastream :: M input_messages :: Vector{MessageObservable{AbstractMessage}} + messageout :: MessageObservable{Message} marginal :: MarginalObservable - messageout :: M prediction :: P end function DataVariable() - messageout = RecentSubject(Message) - marginal = MarginalObservable() + datastream = Subject(Message) + messageout = MessageObservable(Message) + marginal = MarginalObservable() prediction = MarginalObservable() - return DataVariable(Vector{MessageObservable{AbstractMessage}}(), marginal, messageout, prediction) + return DataVariable(datastream, Vector{MessageObservable{AbstractMessage}}(), messageout, marginal, prediction) end datavar() = DataVariable() @@ -46,22 +48,28 @@ struct DataVariableActivationOptions args end +DataVariableActivationOptions() = DataVariableActivationOptions(false, false, nothing, nothing) + function activate!(datavar::DataVariable, options::DataVariableActivationOptions) - if true # options.prediction + if options.prediction _setprediction!(datavar, _makeprediction(datavar)) end - # If the variable is not linked to another we simply redirect the message as a marginal + # If the variable is not linked to another we simply redirect the message from the datastream if !options.linked - connect!(datavar.marginal, datavar.messageout |> map(Marginal, as_marginal)) + connect!(datavar.messageout, datavar.datastream) else # If the variable is linked to another we need to apply a transformation from the linked variables linkvalues = combineLatestUpdates(map(l -> __link_getmarginal(l), options.args)) - linkstream = linkvalues |> map(Marginal, (args) -> let f = options.transform - return Marginal(__apply_link(f, getrecent.(args)), false, false, nothing) + linkstream = linkvalues |> map(Message, (args) -> let f = options.transform + return Message(__apply_link(f, getrecent.(args)), false, false, nothing) end) - connect!(datavar.marginal, linkstream) + connect!(datavar.messageout, merged((datavar.datastream, linkstream))) end + + # The marginal stream is always the same as the message out + connect!(datavar.marginal, datavar.messageout |> map(Marginal, as_marginal)) + return nothing end @@ -76,8 +84,8 @@ _getmarginal(datavar::DataVariable) = datavar.marginal _setmarginal!(::DataVariable, observable) = error("It is not possible to set a marginal stream for `DataVariable`") _makemarginal(::DataVariable) = error("It is not possible to make marginal stream for `DataVariable`") -update!(datavar::DataVariable, data) = next!(messageout(datavar, 1), Message(PointMass(data), false, false, nothing)) -update!(datavar::DataVariable, ::Missing) = next!(messageout(datavar, 1), Message(missing, false, false, nothing)) +update!(datavar::DataVariable, data) = next!(datavar.datastream, Message(PointMass(data), false, false, nothing)) +update!(datavar::DataVariable, ::Missing) = next!(datavar.datastream, Message(missing, false, false, nothing)) function update!(datavars::AbstractArray{<:DataVariable}, data::AbstractArray) @assert size(datavars) === size(data) """ diff --git a/test/nodes/predefined/delta/delta_tests.jl b/test/nodes/predefined/delta/delta_tests.jl index 1254e2d05..ac3611681 100644 --- a/test/nodes/predefined/delta/delta_tests.jl +++ b/test/nodes/predefined/delta/delta_tests.jl @@ -1,6 +1,7 @@ @testitem "DeltaNode - creation with static inputs (simple case) #1" begin - import ReactiveMP: nodefunction, DeltaMeta, Linearization + using Rocket + import ReactiveMP: nodefunction, DeltaMeta, Linearization, messageout, activate!, RandomVariableActivationOptions, DataVariableActivationOptions foo(x, y, z) = x * y + z @@ -13,6 +14,12 @@ node = factornode(foo, [(:out, out), (:in, x), (:in, y), (:in, z)], ((1, 2, 3, 4),)) meta = DeltaMeta(method = Linearization()) + activate!(x, RandomVariableActivationOptions()) + activate!(y, DataVariableActivationOptions()) + # We need to subscribe on the message, otherwise the recent message + # will not be propagated to the nodefunction + subscribe!(messageout(y, 1), void()) + update!(y, 2.0) for xval in rand(10) @@ -23,7 +30,8 @@ end @testitem "DeltaNode - Creation with static inputs (all permutations) #2" begin - import ReactiveMP: nodefunction, DeltaMeta, Linearization + using Rocket + import ReactiveMP: nodefunction, DeltaMeta, Linearization, messageout, activate!, RandomVariableActivationOptions, DataVariableActivationOptions foo1(x, y, z) = x * y + z foo2(x, y, z) = x / y - z @@ -35,7 +43,19 @@ end # In this test we attempt to create a lot of possible combinations # of random, data and const inputs to the delta node - create_interfaces(i) = ((:in, randomvar()), (:in, datavar()), (:in, constvar(vals[i]))) + function create_interfaces(i) + r = randomvar() + d = datavar() + c = constvar(vals[i]) + + activate!(r, RandomVariableActivationOptions()) + activate!(d, DataVariableActivationOptions()) + # We need to subscribe on the message, otherwise the recent message + # will not be propagated to the nodefunction + subscribe!(messageout(d, 1), void()) + + return ((:in, r), (:in, d), (:in, c)) + end for x in create_interfaces(1), y in create_interfaces(2), z in create_interfaces(3) in_interfaces = [x, y, z] diff --git a/test/testutilities.jl b/test/testutilities.jl index 38eb7b30b..cba57a881 100644 --- a/test/testutilities.jl +++ b/test/testutilities.jl @@ -8,12 +8,15 @@ function check_stream_updated_once(f, stream) value = Ref{Any}(missing) subscription = subscribe!(stream, (new_value) -> begin if stream_updated - error("Stream was updated more than once") + error("Stream was updated more than once. Recorded value: $(value[]). New value: $(new_value)") end value[] = new_value stream_updated = true end) f() + if !stream_updated + error("Stream was not updated") + end @test stream_updated unsubscribe!(subscription) return value[] diff --git a/test/variables/data_tests.jl b/test/variables/data_tests.jl index c8bca14f2..5f390d4ec 100644 --- a/test/variables/data_tests.jl +++ b/test/variables/data_tests.jl @@ -69,3 +69,130 @@ end end end end + +@testitem "DataVariable: linked variable" begin + using BayesBase + import ReactiveMP: DataVariable, DataVariableActivationOptions, activate!, messageout + + include("../testutilities.jl") + + for fn in (+, *), val1 in 1:3, val2 in 1:3 + @testset begin + var = datavar() + options = DataVariableActivationOptions(true, true, fn, (val1, val2)) + activate!(var, options) + marginal = check_stream_updated_once(getmarginal(var)) + @test getdata(marginal) === PointMass(fn(val1, val2)) + + # `|> take(1)` here is a hack, because the value updates twice + # first its being retranslated from the subscription to the marginal, + # then its going to be recomputed again from the linked datavars + message = check_stream_updated_once(messageout(var, 1) |> take(1)) + @test getdata(message) === PointMass(fn(val1, val2)) + end + + # Just marginal + @testset begin + var = datavar() + options = DataVariableActivationOptions(true, true, fn, (val1, val2)) + activate!(var, options) + marginal = check_stream_updated_once(getmarginal(var)) + @test getdata(marginal) === PointMass(fn(val1, val2)) + end + + # Just message + @testset begin + var = datavar() + options = DataVariableActivationOptions(true, true, fn, (val1, val2)) + activate!(var, options) + message = check_stream_updated_once(messageout(var, 1)) + @test getdata(message) === PointMass(fn(val1, val2)) + end + + @testset begin + var1 = datavar() + activate!(var1, DataVariableActivationOptions(true, false, nothing, nothing)) + + var = datavar() + options = DataVariableActivationOptions(true, true, fn, (var1, val2)) + activate!(var, options) + @test check_stream_not_updated(getmarginal(var)) + + marginal = check_stream_updated_once(getmarginal(var)) do + update!(var1, val1) + end + @test getdata(marginal) === PointMass(fn(val1, val2)) + + # The message should preserve the value from the previous update + # `|> take(1)` here is a hack, because the value updates twice + # first its being retranslated from the subscription to the marginal, + # then its going to be recomputed again from the linked datavars + message = check_stream_updated_once(messageout(var, 1) |> take(1)) + @test getdata(message) === PointMass(fn(val1, val2)) + end + + @testset begin + var2 = datavar() + activate!(var2, DataVariableActivationOptions(true, false, nothing, nothing)) + + var = datavar() + options = DataVariableActivationOptions(true, true, fn, (val1, var2)) + activate!(var, options) + @test check_stream_not_updated(getmarginal(var)) + + marginal = check_stream_updated_once(getmarginal(var)) do + update!(var2, val2) + end + @test getdata(marginal) === PointMass(fn(val1, val2)) + + # `|> take(1)` here is a hack, because the value updates twice + # first its being retranslated from the subscription to the marginal, + # then its going to be recomputed again from the linked datavars + message = check_stream_updated_once(messageout(var, 1) |> take(1)) + @test getdata(message) === PointMass(fn(val1, val2)) + end + + @testset begin + var1 = datavar() + var2 = datavar() + activate!(var1, DataVariableActivationOptions(true, false, nothing, nothing)) + activate!(var2, DataVariableActivationOptions(true, false, nothing, nothing)) + + var = datavar() + options = DataVariableActivationOptions(true, true, fn, (var1, var2)) + activate!(var, options) + @test check_stream_not_updated(getmarginal(var)) + + marginal = check_stream_updated_once(getmarginal(var)) do + update!(var1, val1) + update!(var2, val2) + end + @test getdata(marginal) === PointMass(fn(val1, val2)) + + # `|> take(1)` here is a hack, because the value updates twice + # first its being retranslated from the subscription to the marginal, + # then its going to be recomputed again from the linked datavars + message = check_stream_updated_once(messageout(var, 1) |> take(1)) + @test getdata(message) === PointMass(fn(val1, val2)) + end + + @testset begin + var1 = datavar() + var2 = datavar() + activate!(var1, DataVariableActivationOptions(true, false, nothing, nothing)) + activate!(var2, DataVariableActivationOptions(true, false, nothing, nothing)) + + var = datavar() + options = DataVariableActivationOptions(true, true, fn, (var1, var2)) + activate!(var, options) + @test check_stream_not_updated(getmarginal(var)) + + # We still should be able to update the stream manually + marginal = check_stream_updated_once(getmarginal(var)) do + update!(var, 4) + end + @test getdata(marginal) === PointMass(4) + end + end + +end \ No newline at end of file From 82f8eb422ab0f1bac20ae607874e9f4c25dd632e Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Mon, 22 Apr 2024 11:26:20 +0200 Subject: [PATCH 2/4] Improve the performance of the fix, fix weird edge cases --- src/variables/data.jl | 28 ++++++++++------------ test/nodes/predefined/delta/delta_tests.jl | 6 ----- test/variables/data_tests.jl | 23 ++++-------------- 3 files changed, 17 insertions(+), 40 deletions(-) diff --git a/src/variables/data.jl b/src/variables/data.jl index d43676c22..8873c9159 100644 --- a/src/variables/data.jl +++ b/src/variables/data.jl @@ -1,19 +1,17 @@ export datavar, DataVariable, update!, DataVariableActivationOptions mutable struct DataVariable{M, P} <: AbstractVariable - datastream :: M input_messages :: Vector{MessageObservable{AbstractMessage}} - messageout :: MessageObservable{Message} marginal :: MarginalObservable + messageout :: M prediction :: P end function DataVariable() - datastream = Subject(Message) - messageout = MessageObservable(Message) + messageout = RecentSubject(Message) marginal = MarginalObservable() prediction = MarginalObservable() - return DataVariable(datastream, Vector{MessageObservable{AbstractMessage}}(), messageout, marginal, prediction) + return DataVariable(Vector{MessageObservable{AbstractMessage}}(), marginal, messageout, prediction) end datavar() = DataVariable() @@ -55,16 +53,15 @@ function activate!(datavar::DataVariable, options::DataVariableActivationOptions _setprediction!(datavar, _makeprediction(datavar)) end - # If the variable is not linked to another we simply redirect the message from the datastream - if !options.linked - connect!(datavar.messageout, datavar.datastream) - else + if options.linked # If the variable is linked to another we need to apply a transformation from the linked variables + # and redirect the updates to the `datavar` messageout stream linkvalues = combineLatestUpdates(map(l -> __link_getmarginal(l), options.args)) - linkstream = linkvalues |> map(Message, (args) -> let f = options.transform - return Message(__apply_link(f, getrecent.(args)), false, false, nothing) + linkstream = linkvalues |> map(Any, (args) -> let f = options.transform + return __apply_link(f, getrecent.(args)) end) - connect!(datavar.messageout, merged((datavar.datastream, linkstream))) + # This subscription should unsubscribe automatically when the linked `datavar`s complete + subscribe!(linkstream, (val) -> update!(datavar, val)) end # The marginal stream is always the same as the message out @@ -78,14 +75,15 @@ __link_getmarginal(l::AbstractVariable) = getmarginal(l, IncludeAll()) __link_getmarginal(l::AbstractArray{<:AbstractVariable}) = getmarginals(l, IncludeAll()) __apply_link(f::F, args) where {F} = __apply_link(f, getdata.(args)) -__apply_link(f::F, args::NTuple{N, PointMass}) where {F, N} = PointMass(f(mean.(args)...)) +__apply_link(f::F, args::NTuple{N, PointMass}) where {F, N} = f(mean.(args)...) _getmarginal(datavar::DataVariable) = datavar.marginal _setmarginal!(::DataVariable, observable) = error("It is not possible to set a marginal stream for `DataVariable`") _makemarginal(::DataVariable) = error("It is not possible to make marginal stream for `DataVariable`") -update!(datavar::DataVariable, data) = next!(datavar.datastream, Message(PointMass(data), false, false, nothing)) -update!(datavar::DataVariable, ::Missing) = next!(datavar.datastream, Message(missing, false, false, nothing)) +update!(datavar::DataVariable, data) = update!(datavar, PointMass(data)) +update!(datavar::DataVariable, data::PointMass) = next!(datavar.messageout, Message(data, false, false, nothing)) +update!(datavar::DataVariable, ::Missing) = next!(datavar.messageout, Message(missing, false, false, nothing)) function update!(datavars::AbstractArray{<:DataVariable}, data::AbstractArray) @assert size(datavars) === size(data) """ diff --git a/test/nodes/predefined/delta/delta_tests.jl b/test/nodes/predefined/delta/delta_tests.jl index ac3611681..f969709ce 100644 --- a/test/nodes/predefined/delta/delta_tests.jl +++ b/test/nodes/predefined/delta/delta_tests.jl @@ -16,9 +16,6 @@ activate!(x, RandomVariableActivationOptions()) activate!(y, DataVariableActivationOptions()) - # We need to subscribe on the message, otherwise the recent message - # will not be propagated to the nodefunction - subscribe!(messageout(y, 1), void()) update!(y, 2.0) @@ -50,9 +47,6 @@ end activate!(r, RandomVariableActivationOptions()) activate!(d, DataVariableActivationOptions()) - # We need to subscribe on the message, otherwise the recent message - # will not be propagated to the nodefunction - subscribe!(messageout(d, 1), void()) return ((:in, r), (:in, d), (:in, c)) end diff --git a/test/variables/data_tests.jl b/test/variables/data_tests.jl index 5f390d4ec..5015a76e7 100644 --- a/test/variables/data_tests.jl +++ b/test/variables/data_tests.jl @@ -83,11 +83,7 @@ end activate!(var, options) marginal = check_stream_updated_once(getmarginal(var)) @test getdata(marginal) === PointMass(fn(val1, val2)) - - # `|> take(1)` here is a hack, because the value updates twice - # first its being retranslated from the subscription to the marginal, - # then its going to be recomputed again from the linked datavars - message = check_stream_updated_once(messageout(var, 1) |> take(1)) + message = check_stream_updated_once(messageout(var, 1)) @test getdata(message) === PointMass(fn(val1, val2)) end @@ -122,12 +118,7 @@ end update!(var1, val1) end @test getdata(marginal) === PointMass(fn(val1, val2)) - - # The message should preserve the value from the previous update - # `|> take(1)` here is a hack, because the value updates twice - # first its being retranslated from the subscription to the marginal, - # then its going to be recomputed again from the linked datavars - message = check_stream_updated_once(messageout(var, 1) |> take(1)) + message = check_stream_updated_once(messageout(var, 1)) @test getdata(message) === PointMass(fn(val1, val2)) end @@ -145,10 +136,7 @@ end end @test getdata(marginal) === PointMass(fn(val1, val2)) - # `|> take(1)` here is a hack, because the value updates twice - # first its being retranslated from the subscription to the marginal, - # then its going to be recomputed again from the linked datavars - message = check_stream_updated_once(messageout(var, 1) |> take(1)) + message = check_stream_updated_once(messageout(var, 1)) @test getdata(message) === PointMass(fn(val1, val2)) end @@ -169,10 +157,7 @@ end end @test getdata(marginal) === PointMass(fn(val1, val2)) - # `|> take(1)` here is a hack, because the value updates twice - # first its being retranslated from the subscription to the marginal, - # then its going to be recomputed again from the linked datavars - message = check_stream_updated_once(messageout(var, 1) |> take(1)) + message = check_stream_updated_once(messageout(var, 1)) @test getdata(message) === PointMass(fn(val1, val2)) end From fdad33a1ef50f8aaabadd178541c02b03596aaf9 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Mon, 22 Apr 2024 11:32:19 +0200 Subject: [PATCH 3/4] revert unnecessary change --- src/message.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/message.jl b/src/message.jl index 4bb495daf..2e3787be8 100644 --- a/src/message.jl +++ b/src/message.jl @@ -241,13 +241,12 @@ dropproxytype(::Type{<:Message{T}}) where {T} = T ## Message observable -struct MessageObservable{M <: AbstractMessage, S <: AbstractSubject} <: Subscribable{M} - subject :: S +struct MessageObservable{M <: AbstractMessage} <: Subscribable{M} + subject :: Rocket.RecentSubjectInstance{M, Subject{M, AsapScheduler, AsapScheduler}} stream :: LazyObservable{M} end -MessageObservable(::Type{M} = AbstractMessage) where {M} = MessageObservable(RecentSubject(M), lazy(M)) -MessageObservable(subject::AbstractSubject{M}) where {M} = MessageObservable(subject, lazy(M)) +MessageObservable(::Type{M} = AbstractMessage) where {M} = MessageObservable{M}(RecentSubject(M), lazy(M)) Rocket.getrecent(observable::MessageObservable) = Rocket.getrecent(observable.subject) From 3fee1cc36bb70430e53572963b18a9f76dd5991b Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Mon, 22 Apr 2024 13:01:20 +0200 Subject: [PATCH 4/4] style: make format --- test/nodes/predefined/delta/delta_tests.jl | 2 +- test/variables/data_tests.jl | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/test/nodes/predefined/delta/delta_tests.jl b/test/nodes/predefined/delta/delta_tests.jl index f969709ce..904b16499 100644 --- a/test/nodes/predefined/delta/delta_tests.jl +++ b/test/nodes/predefined/delta/delta_tests.jl @@ -40,7 +40,7 @@ end # In this test we attempt to create a lot of possible combinations # of random, data and const inputs to the delta node - function create_interfaces(i) + function create_interfaces(i) r = randomvar() d = datavar() c = constvar(vals[i]) diff --git a/test/variables/data_tests.jl b/test/variables/data_tests.jl index 5015a76e7..7a38d5a30 100644 --- a/test/variables/data_tests.jl +++ b/test/variables/data_tests.jl @@ -114,7 +114,7 @@ end activate!(var, options) @test check_stream_not_updated(getmarginal(var)) - marginal = check_stream_updated_once(getmarginal(var)) do + marginal = check_stream_updated_once(getmarginal(var)) do update!(var1, val1) end @test getdata(marginal) === PointMass(fn(val1, val2)) @@ -131,7 +131,7 @@ end activate!(var, options) @test check_stream_not_updated(getmarginal(var)) - marginal = check_stream_updated_once(getmarginal(var)) do + marginal = check_stream_updated_once(getmarginal(var)) do update!(var2, val2) end @test getdata(marginal) === PointMass(fn(val1, val2)) @@ -151,7 +151,7 @@ end activate!(var, options) @test check_stream_not_updated(getmarginal(var)) - marginal = check_stream_updated_once(getmarginal(var)) do + marginal = check_stream_updated_once(getmarginal(var)) do update!(var1, val1) update!(var2, val2) end @@ -173,11 +173,10 @@ end @test check_stream_not_updated(getmarginal(var)) # We still should be able to update the stream manually - marginal = check_stream_updated_once(getmarginal(var)) do + marginal = check_stream_updated_once(getmarginal(var)) do update!(var, 4) end @test getdata(marginal) === PointMass(4) end end - -end \ No newline at end of file +end