Skip to content

Commit

Permalink
Merge pull request #393 from ReactiveBayes/linked-datavar-msg-fix
Browse files Browse the repository at this point in the history
Fix for linked datavars
  • Loading branch information
bvdmitri authored Apr 23, 2024
2 parents 35563e1 + 3fee1cc commit 171417b
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 16 deletions.
30 changes: 18 additions & 12 deletions src/variables/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ end

function DataVariable()
messageout = RecentSubject(Message)
marginal = MarginalObservable()
marginal = MarginalObservable()
prediction = MarginalObservable()
return DataVariable(Vector{MessageObservable{AbstractMessage}}(), marginal, messageout, prediction)
end
Expand Down Expand Up @@ -46,22 +46,27 @@ struct DataVariableActivationOptions
args
end

DataVariableActivationOptions() = DataVariableActivationOptions(false, false, nothing, nothing)

function activate!(datavar::DataVariable, options::DataVariableActivationOptions)
if true # options.prediction
if options.prediction
_setprediction!(datavar, _makeprediction(datavar))
end

# If the variable is not linked to another we simply redirect the message as a marginal
if !options.linked
connect!(datavar.marginal, datavar.messageout |> map(Marginal, as_marginal))
else
if options.linked
# If the variable is linked to another we need to apply a transformation from the linked variables
# and redirect the updates to the `datavar` messageout stream
linkvalues = combineLatestUpdates(map(l -> __link_getmarginal(l), options.args))
linkstream = linkvalues |> map(Marginal, (args) -> let f = options.transform
return Marginal(__apply_link(f, getrecent.(args)), false, false, nothing)
linkstream = linkvalues |> map(Any, (args) -> let f = options.transform
return __apply_link(f, getrecent.(args))
end)
connect!(datavar.marginal, linkstream)
# This subscription should unsubscribe automatically when the linked `datavar`s complete
subscribe!(linkstream, (val) -> update!(datavar, val))
end

# The marginal stream is always the same as the message out
connect!(datavar.marginal, datavar.messageout |> map(Marginal, as_marginal))

return nothing
end

Expand All @@ -70,14 +75,15 @@ __link_getmarginal(l::AbstractVariable) = getmarginal(l, IncludeAll())
__link_getmarginal(l::AbstractArray{<:AbstractVariable}) = getmarginals(l, IncludeAll())

__apply_link(f::F, args) where {F} = __apply_link(f, getdata.(args))
__apply_link(f::F, args::NTuple{N, PointMass}) where {F, N} = PointMass(f(mean.(args)...))
__apply_link(f::F, args::NTuple{N, PointMass}) where {F, N} = f(mean.(args)...)

_getmarginal(datavar::DataVariable) = datavar.marginal
_setmarginal!(::DataVariable, observable) = error("It is not possible to set a marginal stream for `DataVariable`")
_makemarginal(::DataVariable) = error("It is not possible to make marginal stream for `DataVariable`")

update!(datavar::DataVariable, data) = next!(messageout(datavar, 1), Message(PointMass(data), false, false, nothing))
update!(datavar::DataVariable, ::Missing) = next!(messageout(datavar, 1), Message(missing, false, false, nothing))
update!(datavar::DataVariable, data) = update!(datavar, PointMass(data))
update!(datavar::DataVariable, data::PointMass) = next!(datavar.messageout, Message(data, false, false, nothing))
update!(datavar::DataVariable, ::Missing) = next!(datavar.messageout, Message(missing, false, false, nothing))

function update!(datavars::AbstractArray{<:DataVariable}, data::AbstractArray)
@assert size(datavars) === size(data) """
Expand Down
20 changes: 17 additions & 3 deletions test/nodes/predefined/delta/delta_tests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

@testitem "DeltaNode - creation with static inputs (simple case) #1" begin
import ReactiveMP: nodefunction, DeltaMeta, Linearization
using Rocket
import ReactiveMP: nodefunction, DeltaMeta, Linearization, messageout, activate!, RandomVariableActivationOptions, DataVariableActivationOptions

foo(x, y, z) = x * y + z

Expand All @@ -13,6 +14,9 @@
node = factornode(foo, [(:out, out), (:in, x), (:in, y), (:in, z)], ((1, 2, 3, 4),))
meta = DeltaMeta(method = Linearization())

activate!(x, RandomVariableActivationOptions())
activate!(y, DataVariableActivationOptions())

update!(y, 2.0)

for xval in rand(10)
Expand All @@ -23,7 +27,8 @@
end

@testitem "DeltaNode - Creation with static inputs (all permutations) #2" begin
import ReactiveMP: nodefunction, DeltaMeta, Linearization
using Rocket
import ReactiveMP: nodefunction, DeltaMeta, Linearization, messageout, activate!, RandomVariableActivationOptions, DataVariableActivationOptions

foo1(x, y, z) = x * y + z
foo2(x, y, z) = x / y - z
Expand All @@ -35,7 +40,16 @@ end

# In this test we attempt to create a lot of possible combinations
# of random, data and const inputs to the delta node
create_interfaces(i) = ((:in, randomvar()), (:in, datavar()), (:in, constvar(vals[i])))
function create_interfaces(i)
r = randomvar()
d = datavar()
c = constvar(vals[i])

activate!(r, RandomVariableActivationOptions())
activate!(d, DataVariableActivationOptions())

return ((:in, r), (:in, d), (:in, c))
end

for x in create_interfaces(1), y in create_interfaces(2), z in create_interfaces(3)
in_interfaces = [x, y, z]
Expand Down
5 changes: 4 additions & 1 deletion test/testutilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@ function check_stream_updated_once(f, stream)
value = Ref{Any}(missing)
subscription = subscribe!(stream, (new_value) -> begin
if stream_updated
error("Stream was updated more than once")
error("Stream was updated more than once. Recorded value: $(value[]). New value: $(new_value)")
end
value[] = new_value
stream_updated = true
end)
f()
if !stream_updated
error("Stream was not updated")
end
@test stream_updated
unsubscribe!(subscription)
return value[]
Expand Down
111 changes: 111 additions & 0 deletions test/variables/data_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,114 @@ end
end
end
end

@testitem "DataVariable: linked variable" begin
using BayesBase
import ReactiveMP: DataVariable, DataVariableActivationOptions, activate!, messageout

include("../testutilities.jl")

for fn in (+, *), val1 in 1:3, val2 in 1:3
@testset begin
var = datavar()
options = DataVariableActivationOptions(true, true, fn, (val1, val2))
activate!(var, options)
marginal = check_stream_updated_once(getmarginal(var))
@test getdata(marginal) === PointMass(fn(val1, val2))
message = check_stream_updated_once(messageout(var, 1))
@test getdata(message) === PointMass(fn(val1, val2))
end

# Just marginal
@testset begin
var = datavar()
options = DataVariableActivationOptions(true, true, fn, (val1, val2))
activate!(var, options)
marginal = check_stream_updated_once(getmarginal(var))
@test getdata(marginal) === PointMass(fn(val1, val2))
end

# Just message
@testset begin
var = datavar()
options = DataVariableActivationOptions(true, true, fn, (val1, val2))
activate!(var, options)
message = check_stream_updated_once(messageout(var, 1))
@test getdata(message) === PointMass(fn(val1, val2))
end

@testset begin
var1 = datavar()
activate!(var1, DataVariableActivationOptions(true, false, nothing, nothing))

var = datavar()
options = DataVariableActivationOptions(true, true, fn, (var1, val2))
activate!(var, options)
@test check_stream_not_updated(getmarginal(var))

marginal = check_stream_updated_once(getmarginal(var)) do
update!(var1, val1)
end
@test getdata(marginal) === PointMass(fn(val1, val2))
message = check_stream_updated_once(messageout(var, 1))
@test getdata(message) === PointMass(fn(val1, val2))
end

@testset begin
var2 = datavar()
activate!(var2, DataVariableActivationOptions(true, false, nothing, nothing))

var = datavar()
options = DataVariableActivationOptions(true, true, fn, (val1, var2))
activate!(var, options)
@test check_stream_not_updated(getmarginal(var))

marginal = check_stream_updated_once(getmarginal(var)) do
update!(var2, val2)
end
@test getdata(marginal) === PointMass(fn(val1, val2))

message = check_stream_updated_once(messageout(var, 1))
@test getdata(message) === PointMass(fn(val1, val2))
end

@testset begin
var1 = datavar()
var2 = datavar()
activate!(var1, DataVariableActivationOptions(true, false, nothing, nothing))
activate!(var2, DataVariableActivationOptions(true, false, nothing, nothing))

var = datavar()
options = DataVariableActivationOptions(true, true, fn, (var1, var2))
activate!(var, options)
@test check_stream_not_updated(getmarginal(var))

marginal = check_stream_updated_once(getmarginal(var)) do
update!(var1, val1)
update!(var2, val2)
end
@test getdata(marginal) === PointMass(fn(val1, val2))

message = check_stream_updated_once(messageout(var, 1))
@test getdata(message) === PointMass(fn(val1, val2))
end

@testset begin
var1 = datavar()
var2 = datavar()
activate!(var1, DataVariableActivationOptions(true, false, nothing, nothing))
activate!(var2, DataVariableActivationOptions(true, false, nothing, nothing))

var = datavar()
options = DataVariableActivationOptions(true, true, fn, (var1, var2))
activate!(var, options)
@test check_stream_not_updated(getmarginal(var))

# We still should be able to update the stream manually
marginal = check_stream_updated_once(getmarginal(var)) do
update!(var, 4)
end
@test getdata(marginal) === PointMass(4)
end
end
end

0 comments on commit 171417b

Please sign in to comment.