add maxent GFNs #124

Merged: 4 commits, Mar 27, 2024
60 changes: 58 additions & 2 deletions src/gflownet/algo/config.py
@@ -1,16 +1,57 @@
from dataclasses import dataclass, field
from enum import Enum
from enum import IntEnum
from typing import Optional


class TBVariant(int, Enum):
class Backward(IntEnum):
"""
See algo.trajectory_balance.TrajectoryBalance for details.
The A variants of `Maxent` and `GSQL` require the environment to provide $n$.
This is true for sEH but not QM9.
"""

Uniform = 1
Free = 2
Maxent = 3
MaxentA = 4
GSQL = 5
GSQLA = 6


class NLoss(IntEnum):
"""See algo.trajectory_balance.TrajectoryBalance for details."""

none = 0
Transition = 1
SubTB1 = 2
TermTB1 = 3
StartTB1 = 4
TB = 5


class TBVariant(IntEnum):
"""See algo.trajectory_balance.TrajectoryBalance for details."""

TB = 0
SubTB1 = 1
DB = 2


class LossFN(IntEnum):
"""
The loss function to use.

- GHL: Kaan Gokcesu, Hakan Gokcesu
https://arxiv.org/pdf/2108.12627.pdf,
Note: This can be used as a differentiable version of HUB.
"""

MSE = 0
MAE = 1
HUB = 2
GHL = 3


@dataclass
class TBConfig:
"""Trajectory Balance config.
@@ -39,6 +80,16 @@ class TBConfig:
The learning rate for the logZ parameter (only relevant when do_subtb is False)
Z_lr_decay : float
The learning rate decay for the logZ parameter (only relevant when do_subtb is False)
loss_fn: LossFN
The loss function to use
loss_fn_par: float
The loss function parameter; for the Huber loss it is the delta
n_loss: NLoss
The $n$ loss to use (defaults to NLoss.none, i.e., do not learn $n$)
n_loss_multiplier: float
The multiplier for the $n$ loss
backward_policy: Backward
The backward policy to use
"""

bootstrap_own_reward: bool = False
@@ -54,6 +105,11 @@
Z_learning_rate: float = 1e-4
Z_lr_decay: float = 50_000
cum_subtb: bool = True
loss_fn: LossFN = LossFN.MSE
loss_fn_par: float = 1.0
n_loss: NLoss = NLoss.none
n_loss_multiplier: float = 1.0
backward_policy: Backward = Backward.Uniform


@dataclass
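Taken together, the new fields let a run opt into the maxent/GSQL machinery from the config alone. A minimal sketch (assuming the default `Config()` construction works as in the rest of the repo; `do_predict_n` is referenced by this PR's asserts but its declaration sits outside this diff):

from gflownet.algo.config import Backward, LossFN, NLoss
from gflownet.config import Config

config = Config()
# Learn n(s) with a per-transition loss and use the learned maxent backward policy
config.algo.tb.backward_policy = Backward.Maxent
config.algo.tb.n_loss = NLoss.Transition
config.algo.tb.n_loss_multiplier = 1.0
config.algo.tb.do_predict_n = True  # required whenever n_loss != NLoss.none
# Optionally swap MSE for the differentiable Huber-like GHL loss
config.algo.tb.loss_fn = LossFN.GHL
config.algo.tb.loss_fn_par = 1.0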
166 changes: 146 additions & 20 deletions src/gflownet/algo/trajectory_balance.py
@@ -9,7 +9,7 @@
from torch import Tensor
from torch_scatter import scatter, scatter_sum

from gflownet.algo.config import TBVariant
from gflownet.algo.config import Backward, LossFN, NLoss, TBVariant
from gflownet.algo.graph_sampling import GraphSampler
from gflownet.config import Config
from gflownet.envs.graph_building_env import (
@@ -42,7 +42,7 @@ def cross(x: torch.Tensor):
return y[None] - shift_right(y)[:, None]


def subTB(v: torch.tensor, x: torch.Tensor):
def subTB(v: torch.Tensor, x: torch.Tensor):
r"""
Compute the SubTB(1):
$\forall i \leq j: D[i,j] =
@@ -84,6 +84,9 @@ class TrajectoryBalance(GFNAlgorithm):
Note: This is the trajectory version of Detailed Balance (i.e. transitions are not iid, but trajectories are).
Empirical results in subsequent papers suggest that DB may be improved by training on iid transitions (sampled from
a replay buffer) instead of trajectories.

- Maxent[A], GSQL[A], TermTB1, StartTB1: Sobhan Mohammadpour, Emmanuel Bengio, Emma Frejinger, Pierre-Luc Bacon
https://arxiv.org/abs/2312.14331
"""

def __init__(
@@ -115,9 +118,8 @@ def __init__(
self.max_nodes = cfg.algo.max_nodes
self.length_normalize_losses = cfg.algo.tb.do_length_normalize
# Experimental flags
self.reward_loss_is_mae = True
self.tb_loss_is_mae = False
self.tb_loss_is_huber = False
self.reward_loss = self.cfg.loss_fn
self.tb_loss = self.cfg.loss_fn
self.mask_invalid_rewards = False
self.reward_normalize_losses = False
self.sample_temp = 1
@@ -126,6 +128,13 @@
# instead give "ABC...Z" as a single input, but grab the logits at every timestep. Only works if using something
# like a transformer with causal self-attention.
self.model_is_autoregressive = False
assert (
self.cfg.backward_policy not in [Backward.Maxent, Backward.GSQL] or self.cfg.n_loss != NLoss.none
), "can't do maxent w/o learning or knowing $n$"
assert self.ctx.has_n() or (
self.cfg.backward_policy not in [Backward.MaxentA, Backward.GSQLA]
), "can't do analytical maxent/GSQL w/o knowing $n$"
assert self.cfg.do_predict_n or self.cfg.n_loss == NLoss.none, "`n_loss != NLoss.none` requires `do_predict_n`"
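# In short: Maxent/GSQL learn $n$ (they need n_loss != NLoss.none and do_predict_n),
# while MaxentA/GSQLA read $n$ analytically from the environment (they need ctx.has_n()).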
self.random_action_prob = [cfg.algo.train_random_action_prob, cfg.algo.valid_random_action_prob]

self.graph_sampler = GraphSampler(
@@ -372,8 +381,13 @@ def compute_batch_losses(
# of length 4, trajectory 1 of length 3, and so on.
batch_idx = torch.arange(num_trajs, device=dev).repeat_interleave(batch.traj_lens)
# The position of the last graph of each trajectory
final_graph_idx = torch.cumsum(batch.traj_lens, 0) - 1
traj_cumlen = torch.cumsum(batch.traj_lens, 0)
final_graph_idx = traj_cumlen - 1
# The position of the first graph of each trajectory
first_graph_idx = shift_right(traj_cumlen)
final_graph_idx_1 = torch.maximum(final_graph_idx - 1, first_graph_idx)

fwd_cat: GraphActionCategorical
# Forward pass of the model, returns a GraphActionCategorical representing the forward
# policy P_F, optionally a backward policy P_B, and per-graph outputs (e.g. F(s) in SubTB).
if self.cfg.do_parameterize_p_b:
@@ -386,6 +400,12 @@
# Retrieve the reward predictions for the full graphs,
# i.e. the final graph of each trajectory
log_reward_preds = per_graph_out[final_graph_idx, 0]
if self.cfg.do_predict_n:
log_n_preds = per_graph_out[:, 1]
log_n_preds[first_graph_idx] = 0
else:
log_n_preds = None

# Compute trajectory balance objective
log_Z = model.logZ(cond_info)[:, 0]
# Compute the log prob of each action in the trajectory
@@ -444,8 +464,49 @@
log_p_B = batch.log_p_B
assert log_p_F.shape == log_p_B.shape

if self.cfg.n_loss == NLoss.TB:
log_traj_n = scatter(log_p_B, batch_idx, dim=0, dim_size=num_trajs, reduce="sum")
n_loss = self._loss(log_traj_n + log_n_preds[final_graph_idx_1])
else:
n_loss = self.n_loss(log_p_B, log_n_preds, batch.traj_lens)

if self.ctx.has_n() and self.cfg.do_predict_n:
analytical_maxent_backward = self.analytical_maxent_backward(batch, first_graph_idx)
if self.cfg.do_parameterize_p_b:
analytical_maxent_backward = torch.roll(analytical_maxent_backward, -1, 0) * (1 - batch.is_sink)
else:
analytical_maxent_backward = None

if self.cfg.backward_policy in [Backward.GSQL, Backward.GSQLA]:
log_p_B = torch.zeros_like(log_p_B)
nzf = torch.maximum(first_graph_idx, final_graph_idx - 1)
if self.cfg.backward_policy == Backward.GSQLA:
log_p_B[nzf] = -batch.log_n
else:
log_p_B[nzf] = -log_n_preds[nzf]
# Telescoping: n(s_0)/n(s_1) * n(s_1)/n(s_2) * ... = n(s_0)/n(s_T) = 1/n(s_T) since n(s_0) = 1
# (we index nzf, not final_graph_idx, because the last entry is thrown away)
elif self.cfg.backward_policy == Backward.MaxentA:
log_p_B = analytical_maxent_backward

if self.cfg.do_parameterize_p_b:
# Life is pain, log_p_B is one unit too short for all trajs
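# (batch.log_p_B has traj_len - 1 entries per trajectory, one per real transition and
# none for the sink, so indices into it lag the per-state tensors by i for trajectory i)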

log_p_B_unif = torch.zeros_like(log_p_B)
for i, (s, e) in enumerate(zip(first_graph_idx, traj_cumlen)):
log_p_B_unif[s : e - 1] = batch.log_p_B[s - i : e - 1 - i]

if self.cfg.backward_policy == Backward.Uniform:
log_p_B = log_p_B_unif
else:
log_p_B_unif = log_p_B

if self.cfg.backward_policy in [Backward.Maxent, Backward.GSQL]:
log_p_B = log_p_B.detach()
# This is the log probability of each trajectory
traj_log_p_F = scatter(log_p_F, batch_idx, dim=0, dim_size=num_trajs, reduce="sum")
traj_unif_log_p_B = scatter(log_p_B_unif, batch_idx, dim=0, dim_size=num_trajs, reduce="sum")
traj_log_p_B = scatter(log_p_B, batch_idx, dim=0, dim_size=num_trajs, reduce="sum")

if self.cfg.variant == TBVariant.SubTB1:
@@ -463,7 +524,7 @@
F_sn = per_graph_out[:, 0]
F_sm = per_graph_out[:, 0].roll(-1)
F_sm[final_graph_idx] = clip_log_R
transition_losses = (F_sn + log_p_F - F_sm - log_p_B).pow(2)
transition_losses = self._loss(F_sn + log_p_F - F_sm - log_p_B)
traj_losses = scatter(transition_losses, batch_idx, dim=0, dim_size=num_trajs, reduce="sum")
first_graph_idx = torch.zeros_like(batch.traj_lens)
torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:])
@@ -485,12 +546,7 @@
epsilon = torch.tensor([self.cfg.epsilon], device=dev).float()
numerator = torch.logaddexp(numerator, epsilon)
denominator = torch.logaddexp(denominator, epsilon)
if self.tb_loss_is_mae:
traj_losses = abs(numerator - denominator)
elif self.tb_loss_is_huber:
pass # TODO
else:
traj_losses = (numerator - denominator).pow(2)
traj_losses = self._loss(numerator - denominator, self.tb_loss)

# Normalize losses by trajectory length
if self.length_normalize_losses:
@@ -507,27 +563,49 @@

if self.cfg.bootstrap_own_reward:
num_bootstrap = num_bootstrap or len(log_rewards)
if self.reward_loss_is_mae:
reward_losses = abs(log_rewards[:num_bootstrap] - log_reward_preds[:num_bootstrap])
else:
reward_losses = (log_rewards[:num_bootstrap] - log_reward_preds[:num_bootstrap]).pow(2)
reward_losses = self._loss(log_rewards[:num_bootstrap] - log_reward_preds[:num_bootstrap], self.reward_loss)

reward_loss = reward_losses.mean() * self.cfg.reward_loss_multiplier
else:
reward_loss = 0

loss = traj_losses.mean() + reward_loss
n_loss = n_loss.mean()
tb_loss = traj_losses.mean()
loss = tb_loss + reward_loss + self.cfg.n_loss_multiplier * n_loss
info = {
"offline_loss": traj_losses[: batch.num_offline].mean() if batch.num_offline > 0 else 0,
"online_loss": traj_losses[batch.num_offline :].mean() if batch.num_online > 0 else 0,
"reward_loss": reward_loss,
"invalid_trajectories": invalid_mask.sum() / batch.num_online if batch.num_online > 0 else 0,
"invalid_logprob": (invalid_mask * traj_log_p_F).sum() / (invalid_mask.sum() + 1e-4),
"invalid_losses": (invalid_mask * traj_losses).sum() / (invalid_mask.sum() + 1e-4),
"backward_vs_unif": (traj_unif_log_p_B - traj_log_p_B).pow(2).mean(),
"logZ": log_Z.mean(),
"loss": loss.item(),
"n_loss": n_loss,
"tb_loss": tb_loss.item(),
"batch_entropy": -traj_log_p_F.mean(),
"traj_lens": batch.traj_lens.float().mean(),
}
if self.ctx.has_n() and self.cfg.do_predict_n:
info["n_loss_pred"] = scatter(
(log_n_preds - batch.log_ns) ** 2, batch_idx, dim=0, dim_size=num_trajs, reduce="sum"
).mean()
info["n_final_loss"] = torch.mean((log_n_preds[final_graph_idx] - batch.log_n) ** 2)
if self.cfg.do_parameterize_p_b:
info["n_loss_tgsql"] = torch.mean((-batch.log_n - traj_log_p_B) ** 2)
d = analytical_maxent_backward - log_p_B
d = d * d
d[final_graph_idx] = 0
info["n_loss_maxent"] = scatter(d, batch_idx, dim=0, dim_size=num_trajs, reduce="sum").mean()

return loss, info

def analytical_maxent_backward(self, batch, first_graph_idx):
s = shift_right(batch.log_ns)
s[first_graph_idx] = 0
return s - batch.log_ns
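# In log-space the analytical maxent backward is a difference of path counts:
#   log P_B(s_{t-1} | s_t) = log n(s_{t-1}) - log n(s_t), with log n(s_0) = 0,
# which is exactly what the shift_right and subtraction above compute.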

def _init_subtb(self, dev):
r"""Precompute all possible subtrajectory indices that we will use for computing the loss:
\sum_{m=1}^{T-1} \sum_{n=m+1}^T
@@ -615,9 +693,57 @@ def subtb_loss_fast(self, P_F, P_B, F, R, traj_lengths):
P_B_sums = scatter_sum(P_B[idces + offset], dests)
F_start = F[offset : offset + T].repeat_interleave(T - ar[:T])
F_end = F_and_R[fidces]
total_loss[ep] = (F_start - F_end + P_F_sums - P_B_sums).pow(2).sum() / car[T]
total_loss[ep] = self._loss(F_start - F_end + P_F_sums - P_B_sums).sum() / car[T]
return total_loss

def n_loss(self, P_N, N, traj_lengths):
dev = traj_lengths.device
num_trajs = len(traj_lengths)
total_loss = torch.zeros(num_trajs, device=dev)
if self.cfg.n_loss == NLoss.none:
return total_loss
assert self.cfg.do_parameterize_p_b

x = torch.cumsum(traj_lengths, 0)
for ep, (s_idx, e_idx) in enumerate(zip(shift_right(x), x)):
# drop the last (sink) state, which duplicates the final state
e_idx -= 1
total_loss[ep] = self._n_loss(self.cfg.n_loss, P_N[s_idx : e_idx - 1], N[s_idx:e_idx])
return total_loss

def _loss(self, x, loss_fn=None):
if loss_fn is None:
loss_fn = self.cfg.loss_fn
if loss_fn == LossFN.MSE:
return x * x
elif loss_fn == LossFN.MAE:
return torch.abs(x)
elif loss_fn == LossFN.HUB:
ax = torch.abs(x)
d = self.cfg.loss_fn_par
return torch.where(ax < 1, 0.5 * x * x / d, ax / d - 0.5 / d)
elif loss_fn == LossFN.GHL:
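# logaddexp(ax, -ax) - log(2) = log(cosh(ax)): quadratic near 0, asymptotically linear,
# hence a smooth, everywhere-differentiable stand-in for the Huber loss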
ax = self.cfg.loss_fn_par * x
return torch.logaddexp(ax, -ax) - np.log(2)
else:
raise NotImplementedError()

def _n_loss(self, method, P_N, N):
n = len(N)
if method == NLoss.SubTB1:
return self._loss(subTB(N, -P_N)).sum() / (n * n - n) * 2
elif method == NLoss.TermTB1:
return self._loss(subTB(N, -P_N)[:, 0]).mean()
elif method == NLoss.StartTB1:
# return self._loss(subTB(N, -P_N)[0, :]).mean()
return self._loss(N[1:] + torch.cumsum(P_N, -1)).mean()
elif method == NLoss.TB:
return self._loss(P_N.sum() + N[-1])
elif method == NLoss.Transition:
return self._loss(N[1:] + P_N - N[:-1]).mean()
else:
raise NotImplementedError()
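# Toy sanity check (hypothetical numbers, not part of this file): with exact log-counts
# N = [0, 0.7, 1.1] and maxent backward log-probs P_N = [-0.7, -0.4], the residuals vanish:
#   NLoss.Transition: N[1:] + P_N - N[:-1] -> [0., 0.]
#   NLoss.TB:         P_N.sum() + N[-1]    -> 0.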

def subtb_cum(self, P_F, P_B, F, R, traj_lengths):
"""
Calculate the subTB(1) loss (all arguments on log-scale) using dynamic programming.
@@ -636,5 +762,5 @@ def subtb_cum(self, P_F, P_B, F, R, traj_lengths):
n = e_idx - s_idx
fr = torch.cat([F[s_idx:e_idx], torch.tensor([R[ep]], device=F.device)])
p = pdiff[s_idx:e_idx]
total_loss[ep] = subTB(fr, p).pow(2).sum() / (n * n + n) * 2
total_loss[ep] = self._loss(subTB(fr, p)).sum() / (n * n + n) * 2
return total_loss
2 changes: 1 addition & 1 deletion src/gflownet/data/data_source.py
@@ -225,7 +225,7 @@ def create_batch(self, trajs, batch_info):
if "focus_dir" in trajs[0]:
batch.focus_dir = torch.stack([t["focus_dir"] for t in trajs])

if self.ctx.has_n(): # Does this go somewhere else? Require a flag? Might not be cheap to compute
if self.ctx.has_n() and self.cfg.algo.tb.do_predict_n:
log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs]
batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32)
batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32)
2 changes: 1 addition & 1 deletion src/gflownet/tasks/seh_frag.py
@@ -211,7 +211,7 @@ def main():
config.num_training_steps = 1_00
config.validate_every = 20
config.num_final_gen_steps = 10
config.num_workers = 8
config.num_workers = 1
config.opt.lr_decay = 20_000
config.algo.sampling_tau = 0.99
config.cond.temperature.sample_dist = "uniform"