From d9ef43f94de3b38e6f1163b979a30f6c094f40a6 Mon Sep 17 00:00:00 2001 From: Ben Murrell Date: Sun, 13 Jul 2025 22:31:55 +0200 Subject: [PATCH 1/2] Doob UX tweaks --- src/doob.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/doob.jl b/src/doob.jl index 2045417..1a6e5a4 100644 --- a/src/doob.jl +++ b/src/doob.jl @@ -14,7 +14,8 @@ DoobMatchingFlow(P::DiscreteProcess, onescale::Bool) = DoobMatchingFlow(P, onesc onescale(P::DoobMatchingFlow,t) = P.onescale ? (1 .- t) : eltype(t)(1) mulexpand(t,x) = expand(t, ndims(x)) .* x -Flowfusion.bridge(p::DoobMatchingFlow, x0::DiscreteState{<:AbstractArray{<:Signed}}, x1::DiscreteState{<:AbstractArray{<:Signed}}, t) = bridge(p.P, x0, x1, t) +#We could consider making this preserve one-hotness: +Flowfusion.bridge(p::DoobMatchingFlow, x0::DiscreteState, x1::DiscreteState, t) = bridge(p.P, x0, x1, t) function fallback_doob(P::DiscreteProcess, t, Xt::DiscreteState, X1::DiscreteState; delta = eltype(t)(1e-5)) return (tensor(forward(Xt, P, delta) ⊙ backward(X1, P, (1 .- t) .- delta)) .- tensor(onehot(Xt))) ./ delta; @@ -46,7 +47,7 @@ function rate_constraint(Xt, X̂₁, f) return posQt .+ diagQt end -function velo_step(P, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, delta_t, log_velocity, scale) +function velo_step(P, Xₜ::DiscreteState, delta_t, log_velocity, scale) ohXₜ = onehot(Xₜ) velocity = rate_constraint(tensor(ohXₜ), log_velocity, P.transform) .* scale newXₜ = CategoricalLikelihood(eltype(delta_t).(tensor(ohXₜ) .+ (delta_t .* velocity))) @@ -54,12 +55,12 @@ function velo_step(P, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, delta_t, l return rand(newXₜ) end -step(P::DoobMatchingFlow, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, veloX̂₁::Flowfusion.Guide, s₁, s₂) = velo_step(P, Xₜ, s₂ .- s₁, veloX̂₁.H, expand(1 ./ onescale(P, s₁), ndims(veloX̂₁.H))) -step(P::DoobMatchingFlow, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, veloX̂₁, s₁, s₂) = velo_step(P, Xₜ, s₂ .- s₁, veloX̂₁, expand(1 ./ onescale(P, s₁), ndims(veloX̂₁))) +step(P::DoobMatchingFlow, Xₜ::DiscreteState, veloX̂₁::Flowfusion.Guide, s₁, s₂) = velo_step(P, Xₜ, s₂ .- s₁, veloX̂₁.H, expand(1 ./ onescale(P, s₁), ndims(veloX̂₁.H))) +step(P::DoobMatchingFlow, Xₜ::DiscreteState, veloX̂₁, s₁, s₂) = velo_step(P, Xₜ, s₂ .- s₁, veloX̂₁, expand(1 ./ onescale(P, s₁), ndims(veloX̂₁))) function cgm_dloss(P, Xt, X̂₁, doobX₁) Qt = P.transform(X̂₁) return sum((1 .- Xt) .* (Qt .- xlogy.(doobX₁, Qt)), dims = 1) #<- note, diagonals ignored; implicit zero sum end -floss(P::Flowfusion.fbu(DoobMatchingFlow), Xt::Flowfusion.msu(DiscreteState), X̂₁, X₁::Guide, c) = Flowfusion.scaledmaskedmean(cgm_dloss(P, tensor(Xt), tensor(X̂₁), X₁.H), c, Flowfusion.getlmask(X₁)) \ No newline at end of file +floss(P::Flowfusion.fbu(DoobMatchingFlow), Xt::Flowfusion.msu(DiscreteState), X̂₁, X₁::Guide, c) = Flowfusion.scaledmaskedmean(cgm_dloss(P, tensor(Xt), tensor(X̂₁), X₁.H), c, Flowfusion.getlmask(X₁)) From dabcf07d53eacb84838f45406edfafb92017573c Mon Sep 17 00:00:00 2001 From: Ben Murrell Date: Sun, 13 Jul 2025 22:48:36 +0200 Subject: [PATCH 2/2] Adding extra bridge method --- src/doob.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/doob.jl b/src/doob.jl index 1a6e5a4..293e875 100644 --- a/src/doob.jl +++ b/src/doob.jl @@ -15,7 +15,8 @@ onescale(P::DoobMatchingFlow,t) = P.onescale ? (1 .- t) : eltype(t)(1) mulexpand(t,x) = expand(t, ndims(x)) .* x #We could consider making this preserve one-hotness: -Flowfusion.bridge(p::DoobMatchingFlow, x0::DiscreteState, x1::DiscreteState, t) = bridge(p.P, x0, x1, t) +bridge(p::DoobMatchingFlow, x0::DiscreteState, x1::DiscreteState, t) = bridge(p.P, x0, x1, t) +bridge(p::DoobMatchingFlow, x0::DiscreteState, x1::DiscreteState, t0, t) = bridge(p.P, x0, x1, t0, t) function fallback_doob(P::DiscreteProcess, t, Xt::DiscreteState, X1::DiscreteState; delta = eltype(t)(1e-5)) return (tensor(forward(Xt, P, delta) ⊙ backward(X1, P, (1 .- t) .- delta)) .- tensor(onehot(Xt))) ./ delta;