diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py
index e2576982..f6818e15 100644
--- a/src/gflownet/algo/config.py
+++ b/src/gflownet/algo/config.py
@@ -1,9 +1,35 @@
 from dataclasses import dataclass, field
-from enum import Enum
+from enum import IntEnum
 from typing import Optional
 
 
-class TBVariant(int, Enum):
+class Backward(IntEnum):
+    """
+    See algo.trajectory_balance.TrajectoryBalance for details.
+    The `A` variants of `Maxent` and `GSQL` require the environment to provide $n$.
+    This is true for sEH but not QM9.
+    """
+
+    Uniform = 1
+    Free = 2
+    Maxent = 3
+    MaxentA = 4
+    GSQL = 5
+    GSQLA = 6
+
+
+class NLoss(IntEnum):
+    """See algo.trajectory_balance.TrajectoryBalance for details."""
+
+    none = 0
+    Transition = 1
+    SubTB1 = 2
+    TermTB1 = 3
+    StartTB1 = 4
+    TB = 5
+
+
+class TBVariant(IntEnum):
     """See algo.trajectory_balance.TrajectoryBalance for details."""
 
     TB = 0
@@ -11,6 +37,21 @@ class TBVariant(int, Enum):
     DB = 2
 
 
+class LossFN(IntEnum):
+    """
+    The loss function to use.
+
+    - GHL: Kaan Gokcesu, Hakan Gokcesu
+      https://arxiv.org/pdf/2108.12627.pdf
+      Note: this can be used as a differentiable version of HUB.
+    """
+
+    MSE = 0
+    MAE = 1
+    HUB = 2
+    GHL = 3
+
+
 @dataclass
 class TBConfig:
     """Trajectory Balance config.
@@ -39,6 +80,16 @@ class TBConfig:
         The learning rate for the logZ parameter (only relevant when do_subtb is False)
     Z_lr_decay : float
         The learning rate decay for the logZ parameter (only relevant when do_subtb is False)
+    loss_fn: LossFN
+        The loss function to use
+    loss_fn_par: float
+        The loss function parameter; for the Huber loss, this is the delta
+    n_loss: NLoss
+        The $n$ loss to use (defaults to NLoss.none, i.e. do not learn $n$)
+    n_loss_multiplier: float
+        The multiplier for the $n$ loss
+    backward_policy: Backward
+        The backward policy to use
     """
 
     bootstrap_own_reward: bool = False
@@ -54,6 +105,11 @@ class TBConfig:
     Z_learning_rate: float = 1e-4
     Z_lr_decay: float = 50_000
     cum_subtb: bool = True
+    loss_fn: LossFN = LossFN.MSE
+    loss_fn_par: float = 1.0
+    n_loss: NLoss = NLoss.none
+    n_loss_multiplier: float = 1.0
+    backward_policy: Backward = Backward.Uniform
 
 
 @dataclass
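These new fields are consumed by `TrajectoryBalance` below. A minimal sketch of how they are meant to be combined, assuming the repo's usual `init_empty(Config())` entry point; the specific values are illustrative only, and the asserts added in `TrajectoryBalance.__init__` enforce the same constraints:

```python
from gflownet.algo.config import Backward, LossFN, NLoss
from gflownet.config import Config, init_empty

# Illustrative values: a learned maxent backward policy trained with the
# transition-level n-loss, and a log-cosh (GHL) trajectory loss.
cfg = init_empty(Config())
cfg.algo.tb.do_parameterize_p_b = True  # a learned P_B needs its own model head
cfg.algo.tb.do_predict_n = True         # any n_loss != NLoss.none requires this
cfg.algo.tb.backward_policy = Backward.Maxent
cfg.algo.tb.n_loss = NLoss.Transition   # Maxent (non-A) must learn n
cfg.algo.tb.n_loss_multiplier = 1.0
cfg.algo.tb.loss_fn = LossFN.GHL
cfg.algo.tb.loss_fn_par = 1.0           # sharpness parameter for GHL
```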
diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py
index fcd171b4..2b893f33 100644
--- a/src/gflownet/algo/trajectory_balance.py
+++ b/src/gflownet/algo/trajectory_balance.py
@@ -9,7 +9,7 @@
 from torch import Tensor
 from torch_scatter import scatter, scatter_sum
 
-from gflownet.algo.config import TBVariant
+from gflownet.algo.config import Backward, LossFN, NLoss, TBVariant
 from gflownet.algo.graph_sampling import GraphSampler
 from gflownet.config import Config
 from gflownet.envs.graph_building_env import (
@@ -42,7 +42,7 @@ def cross(x: torch.Tensor):
     return y[None] - shift_right(y)[:, None]
 
 
-def subTB(v: torch.tensor, x: torch.Tensor):
+def subTB(v: torch.Tensor, x: torch.Tensor):
     r"""
     Compute the SubTB(1):
     $\forall i \leq j: D[i,j] =
@@ -84,6 +84,9 @@ class TrajectoryBalance(GFNAlgorithm):
       Note: This is the trajectory version of Detailed Balance (i.e. transitions are not iid, but trajectories are).
       Empirical results in subsequent papers suggest that DB may be improved by training on iid transitions
       (sampled from a replay buffer) instead of trajectories.
+
+    - Maxent[A], GSQL[A], TermTB1, StartTB1: Sobhan Mohammadpour, Emmanuel Bengio, Emma Frejinger, Pierre-Luc Bacon
+      https://arxiv.org/abs/2312.14331
     """
 
     def __init__(
@@ -115,9 +118,8 @@
         self.max_nodes = cfg.algo.max_nodes
         self.length_normalize_losses = cfg.algo.tb.do_length_normalize
         # Experimental flags
-        self.reward_loss_is_mae = True
-        self.tb_loss_is_mae = False
-        self.tb_loss_is_huber = False
+        self.reward_loss = self.cfg.loss_fn
+        self.tb_loss = self.cfg.loss_fn
         self.mask_invalid_rewards = False
         self.reward_normalize_losses = False
         self.sample_temp = 1
@@ -126,6 +128,13 @@
         # instead give "ABC...Z" as a single input, but grab the logits at every timestep. Only works if using something
         # like a transformer with causal self-attention.
         self.model_is_autoregressive = False
+        assert (
+            self.cfg.backward_policy not in [Backward.Maxent, Backward.GSQL] or self.cfg.n_loss != NLoss.none
+        ), "can't do maxent w/o learning or knowing $n$"
+        assert self.ctx.has_n() or (
+            self.cfg.backward_policy not in [Backward.MaxentA, Backward.GSQLA]
+        ), "can't do analytical maxent/GSQL w/o knowing $n$"
+        assert self.cfg.do_predict_n or self.cfg.n_loss == NLoss.none, "`n_loss != NLoss.none` requires `do_predict_n`"
 
         self.random_action_prob = [cfg.algo.train_random_action_prob, cfg.algo.valid_random_action_prob]
         self.graph_sampler = GraphSampler(
@@ -372,8 +381,13 @@ def compute_batch_losses(
         # of length 4, trajectory 1 of length 3, and so on.
         batch_idx = torch.arange(num_trajs, device=dev).repeat_interleave(batch.traj_lens)
         # The position of the last graph of each trajectory
-        final_graph_idx = torch.cumsum(batch.traj_lens, 0) - 1
+        traj_cumlen = torch.cumsum(batch.traj_lens, 0)
+        final_graph_idx = traj_cumlen - 1
+        # The position of the first graph of each trajectory
+        first_graph_idx = shift_right(traj_cumlen)
+        final_graph_idx_1 = torch.maximum(final_graph_idx - 1, first_graph_idx)
+        fwd_cat: GraphActionCategorical
 
         # Forward pass of the model, returns a GraphActionCategorical representing the forward
         # policy P_F, optionally a backward policy P_B, and per-graph outputs (e.g. F(s) in SubTB).
         if self.cfg.do_parameterize_p_b:
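A quick sanity check on the index bookkeeping introduced in the hunk above. This is a standalone sketch with a hypothetical two-trajectory batch; `shift_right` is reimplemented here with what its usage implies its semantics are (shift by one, fill the first slot):

```python
import torch

def shift_right(x: torch.Tensor, z=0):
    # Assumed to match the shift_right helper in this file: shift x right by 1,
    # putting z in the first position.
    y = torch.zeros_like(x)
    y[0], y[1:] = z, x[:-1]
    return y

traj_lens = torch.tensor([3, 2])             # two trajectories, lengths 3 and 2
traj_cumlen = torch.cumsum(traj_lens, 0)     # tensor([3, 5])
final_graph_idx = traj_cumlen - 1            # tensor([2, 4]): last state of each traj
first_graph_idx = shift_right(traj_cumlen)   # tensor([0, 3]): first state of each traj
# Second-to-last state, clamped so a length-1 trajectory falls back to its first state:
final_graph_idx_1 = torch.maximum(final_graph_idx - 1, first_graph_idx)  # tensor([1, 3])
```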
@@ -386,6 +400,12 @@
         # Retrieve the reward predictions for the full graphs,
         # i.e. the final graph of each trajectory
         log_reward_preds = per_graph_out[final_graph_idx, 0]
+        if self.cfg.do_predict_n:
+            log_n_preds = per_graph_out[:, 1]
+            log_n_preds[first_graph_idx] = 0
+        else:
+            log_n_preds = None
+
         # Compute trajectory balance objective
         log_Z = model.logZ(cond_info)[:, 0]
         # Compute the log prob of each action in the trajectory
@@ -444,8 +464,49 @@
             log_p_B = batch.log_p_B
         assert log_p_F.shape == log_p_B.shape
 
+        if self.cfg.n_loss == NLoss.TB:
+            log_traj_n = scatter(log_p_B, batch_idx, dim=0, dim_size=num_trajs, reduce="sum")
+            n_loss = self._loss(log_traj_n + log_n_preds[final_graph_idx_1])
+        else:
+            n_loss = self.n_loss(log_p_B, log_n_preds, batch.traj_lens)
+
+        if self.ctx.has_n() and self.cfg.do_predict_n:
+            analytical_maxent_backward = self.analytical_maxent_backward(batch, first_graph_idx)
+            if self.cfg.do_parameterize_p_b:
+                analytical_maxent_backward = torch.roll(analytical_maxent_backward, -1, 0) * (1 - batch.is_sink)
+        else:
+            analytical_maxent_backward = None
+
+        if self.cfg.backward_policy in [Backward.GSQL, Backward.GSQLA]:
+            log_p_B = torch.zeros_like(log_p_B)
+            nzf = torch.maximum(first_graph_idx, final_graph_idx - 1)
+            if self.cfg.backward_policy == Backward.GSQLA:
+                log_p_B[nzf] = -batch.log_n
+            else:
+                # Telescoping: n(s_0)/n(s_1) * n(s_1)/n(s_2) * ... = n(s_0)/n(s_T) = 1/n(s_T) since n(s_0) = 1,
+                # so the trajectory's entire backward probability collapses onto a single transition. We index
+                # nzf rather than final_graph_idx because the last entry of each trajectory is thrown away.
+                log_p_B[nzf] = -log_n_preds[nzf]
+        elif self.cfg.backward_policy == Backward.MaxentA:
+            log_p_B = analytical_maxent_backward
+
+        if self.cfg.do_parameterize_p_b:
+            # Life is pain: log_p_B is one unit too short for each trajectory, so realign it here
+
+            log_p_B_unif = torch.zeros_like(log_p_B)
+            for i, (s, e) in enumerate(zip(first_graph_idx, traj_cumlen)):
+                log_p_B_unif[s : e - 1] = batch.log_p_B[s - i : e - 1 - i]
+
+            if self.cfg.backward_policy == Backward.Uniform:
+                log_p_B = log_p_B_unif
+        else:
+            log_p_B_unif = log_p_B
+
+        if self.cfg.backward_policy in [Backward.Maxent, Backward.GSQL]:
+            log_p_B = log_p_B.detach()
         # This is the log probability of each trajectory
         traj_log_p_F = scatter(log_p_F, batch_idx, dim=0, dim_size=num_trajs, reduce="sum")
+        traj_unif_log_p_B = scatter(log_p_B_unif, batch_idx, dim=0, dim_size=num_trajs, reduce="sum")
         traj_log_p_B = scatter(log_p_B, batch_idx, dim=0, dim_size=num_trajs, reduce="sum")
 
         if self.cfg.variant == TBVariant.SubTB1:
@@ -463,7 +524,7 @@
                 F_sn = per_graph_out[:, 0]
                 F_sm = per_graph_out[:, 0].roll(-1)
                 F_sm[final_graph_idx] = clip_log_R
-                transition_losses = (F_sn + log_p_F - F_sm - log_p_B).pow(2)
+                transition_losses = self._loss(F_sn + log_p_F - F_sm - log_p_B)
                 traj_losses = scatter(transition_losses, batch_idx, dim=0, dim_size=num_trajs, reduce="sum")
                 first_graph_idx = torch.zeros_like(batch.traj_lens)
                 torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:])
@@ -485,12 +546,7 @@
                 epsilon = torch.tensor([self.cfg.epsilon], device=dev).float()
                 numerator = torch.logaddexp(numerator, epsilon)
                 denominator = torch.logaddexp(denominator, epsilon)
-            if self.tb_loss_is_mae:
-                traj_losses = abs(numerator - denominator)
-            elif self.tb_loss_is_huber:
-                pass  # TODO
-            else:
-                traj_losses = (numerator - denominator).pow(2)
+            traj_losses = self._loss(numerator - denominator, self.tb_loss)
 
         # Normalize losses by trajectory length
         if self.length_normalize_losses:
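The analytical maxent backward used above is just a telescoping ratio of state counts. A standalone sketch of what `analytical_maxent_backward` (defined further down) computes for a single trajectory, with hypothetical `log_ns` values:

```python
import torch

# Hypothetical log state-counts along one trajectory s_0..s_3, with n(s_0) = 1.
log_ns = torch.tensor([0.0, 1.1, 2.3, 3.0])  # log n(s_t)

# log P_B(s_{t-1} | s_t) = log n(s_{t-1}) - log n(s_t), with a 0 at t = 0.
s = torch.roll(log_ns, 1)
s[0] = 0.0  # the entry at first_graph_idx is zeroed out
log_p_B = s - log_ns  # tensor([ 0.0, -1.1, -1.2, -0.7])

# Telescoping: the per-trajectory backward log-probability sums to -log n(s_T),
# which is exactly the mass GSQL[A] places on a single transition.
assert torch.isclose(log_p_B.sum(), -log_ns[-1])
```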
@@ -507,15 +563,15 @@
 
         if self.cfg.bootstrap_own_reward:
             num_bootstrap = num_bootstrap or len(log_rewards)
-            if self.reward_loss_is_mae:
-                reward_losses = abs(log_rewards[:num_bootstrap] - log_reward_preds[:num_bootstrap])
-            else:
-                reward_losses = (log_rewards[:num_bootstrap] - log_reward_preds[:num_bootstrap]).pow(2)
+            reward_losses = self._loss(log_rewards[:num_bootstrap] - log_reward_preds[:num_bootstrap], self.reward_loss)
+
             reward_loss = reward_losses.mean() * self.cfg.reward_loss_multiplier
         else:
             reward_loss = 0
 
-        loss = traj_losses.mean() + reward_loss
+        n_loss = n_loss.mean()
+        tb_loss = traj_losses.mean()
+        loss = tb_loss + reward_loss + self.cfg.n_loss_multiplier * n_loss
         info = {
             "offline_loss": traj_losses[: batch.num_offline].mean() if batch.num_offline > 0 else 0,
             "online_loss": traj_losses[batch.num_offline :].mean() if batch.num_online > 0 else 0,
@@ -523,11 +579,33 @@
             "invalid_trajectories": invalid_mask.sum() / batch.num_online if batch.num_online > 0 else 0,
             "invalid_logprob": (invalid_mask * traj_log_p_F).sum() / (invalid_mask.sum() + 1e-4),
             "invalid_losses": (invalid_mask * traj_losses).sum() / (invalid_mask.sum() + 1e-4),
+            "backward_vs_unif": (traj_unif_log_p_B - traj_log_p_B).pow(2).mean(),
             "logZ": log_Z.mean(),
             "loss": loss.item(),
+            "n_loss": n_loss,
+            "tb_loss": tb_loss.item(),
+            "batch_entropy": -traj_log_p_F.mean(),
+            "traj_lens": batch.traj_lens.float().mean(),
         }
 
+        if self.ctx.has_n() and self.cfg.do_predict_n:
+            info["n_loss_pred"] = scatter(
+                (log_n_preds - batch.log_ns) ** 2, batch_idx, dim=0, dim_size=num_trajs, reduce="sum"
+            ).mean()
+            info["n_final_loss"] = torch.mean((log_n_preds[final_graph_idx] - batch.log_n) ** 2)
+            if self.cfg.do_parameterize_p_b:
+                info["n_loss_tgsql"] = torch.mean((-batch.log_n - traj_log_p_B) ** 2)
+                d = analytical_maxent_backward - log_p_B
+                d = d * d
+                d[final_graph_idx] = 0
+                info["n_loss_maxent"] = scatter(d, batch_idx, dim=0, dim_size=num_trajs, reduce="sum").mean()
+
         return loss, info
 
+    def analytical_maxent_backward(self, batch, first_graph_idx):
+        s = shift_right(batch.log_ns)
+        s[first_graph_idx] = 0
+        return s - batch.log_ns
+
     def _init_subtb(self, dev):
         r"""Precompute all possible subtrajectory indices that we will use for computing the loss:
         \sum_{m=1}^{T-1} \sum_{n=m+1}^T
@@ -615,9 +693,57 @@
             P_B_sums = scatter_sum(P_B[idces + offset], dests)
             F_start = F[offset : offset + T].repeat_interleave(T - ar[:T])
             F_end = F_and_R[fidces]
-            total_loss[ep] = (F_start - F_end + P_F_sums - P_B_sums).pow(2).sum() / car[T]
+            total_loss[ep] = self._loss(F_start - F_end + P_F_sums - P_B_sums).sum() / car[T]
+        return total_loss
+
+    def n_loss(self, P_N, N, traj_lengths):
+        dev = traj_lengths.device
+        num_trajs = len(traj_lengths)
+        total_loss = torch.zeros(num_trajs, device=dev)
+        if self.cfg.n_loss == NLoss.none:
+            return total_loss
+        assert self.cfg.do_parameterize_p_b
+
+        x = torch.cumsum(traj_lengths, 0)
+        for ep, (s_idx, e_idx) in enumerate(zip(shift_right(x), x)):
+            # the last state is the sink, a duplicate of the final graph, so drop it
+            e_idx -= 1
+            total_loss[ep] = self._n_loss(self.cfg.n_loss, P_N[s_idx : e_idx - 1], N[s_idx:e_idx])
         return total_loss
 
+    def _loss(self, x, loss_fn=None):
+        if loss_fn is None:
+            loss_fn = self.cfg.loss_fn
+        if loss_fn == LossFN.MSE:
+            return x * x
+        elif loss_fn == LossFN.MAE:
+            return torch.abs(x)
+        elif loss_fn == LossFN.HUB:
+            ax = torch.abs(x)
+            d = self.cfg.loss_fn_par
+            return torch.where(ax <= d, 0.5 * x * x, d * (ax - 0.5 * d))
+        elif loss_fn == LossFN.GHL:
+            ax = self.cfg.loss_fn_par * x
+            return torch.logaddexp(ax, -ax) - np.log(2)
+        else:
+            raise NotImplementedError()
+
+    def _n_loss(self, method, P_N, N):
+        n = len(N)
+        if method == NLoss.SubTB1:
+            return self._loss(subTB(N, -P_N)).sum() / (n * n - n) * 2
+        elif method == NLoss.TermTB1:
+            return self._loss(subTB(N, -P_N)[:, 0]).mean()
+        elif method == NLoss.StartTB1:
+            # equivalent to self._loss(subTB(N, -P_N)[0, :]).mean(), telescoped out:
+            return self._loss(N[1:] + torch.cumsum(P_N, -1)).mean()
+        elif method == NLoss.TB:
+            return self._loss(P_N.sum() + N[-1])
+        elif method == NLoss.Transition:
+            return self._loss(N[1:] + P_N - N[:-1]).mean()
+        else:
+            raise NotImplementedError()
+
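A quick way to see how the `_n_loss` variants relate: if the predicted log-counts `N` and the backward transition log-probs `P_N` satisfy `P_N[t] = N[t] - N[t+1]` (the maxent backward), every residual above telescopes to zero. A standalone sketch with hypothetical values:

```python
import torch

# N[t] = log n(s_t), with n(s_0) = 1; values are hypothetical.
N = torch.tensor([0.0, 1.1, 2.3, 3.0])
P_N = N[:-1] - N[1:]  # log P_B(s_t | s_{t+1}) under the maxent backward

transition = N[1:] + P_N - N[:-1]         # NLoss.Transition residuals
tb = P_N.sum() + N[-1]                    # NLoss.TB residual
start_tb = N[1:] + torch.cumsum(P_N, -1)  # NLoss.StartTB1 residuals

# All residuals vanish at the maxent fixed point; the variants differ in
# which (sub)trajectories they put weight on during training.
assert torch.allclose(transition, torch.zeros_like(transition))
assert torch.isclose(tb, torch.tensor(0.0))
assert torch.allclose(start_tb, torch.zeros_like(start_tb))
```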
     def subtb_cum(self, P_F, P_B, F, R, traj_lengths):
         """
         Calculate the subTB(1) loss (all arguments on log-scale) using dynamic programming.
@@ -636,5 +762,5 @@
             n = e_idx - s_idx
             fr = torch.cat([F[s_idx:e_idx], torch.tensor([R[ep]], device=F.device)])
             p = pdiff[s_idx:e_idx]
-            total_loss[ep] = subTB(fr, p).pow(2).sum() / (n * n + n) * 2
+            total_loss[ep] = self._loss(subTB(fr, p)).sum() / (n * n + n) * 2
         return total_loss
diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py
index d78a2a7f..b8a2c7e1 100644
--- a/src/gflownet/data/data_source.py
+++ b/src/gflownet/data/data_source.py
@@ -225,7 +225,7 @@ def create_batch(self, trajs, batch_info):
         if "focus_dir" in trajs[0]:
             batch.focus_dir = torch.stack([t["focus_dir"] for t in trajs])
 
-        if self.ctx.has_n():  # Does this go somewhere else? Require a flag? Might not be cheap to compute
+        if self.ctx.has_n() and self.cfg.algo.tb.do_predict_n:
             log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs]
             batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32)
             batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32)
diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py
index e64d642d..d9da0386 100644
--- a/src/gflownet/tasks/seh_frag.py
+++ b/src/gflownet/tasks/seh_frag.py
@@ -211,7 +211,7 @@ def main():
     config.num_training_steps = 1_00
     config.validate_every = 20
     config.num_final_gen_steps = 10
-    config.num_workers = 8
+    config.num_workers = 1
     config.opt.lr_decay = 20_000
     config.algo.sampling_tau = 0.99
     config.cond.temperature.sample_dist = "uniform"
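A closing note on the new `_loss` helper in trajectory_balance.py: GHL is `logaddexp(a*x, -a*x) - log 2 = log(cosh(a*x))`, which behaves like `0.5*(a*x)**2` near zero and like `a*|x| - log 2` in the tails, i.e. a differentiable stand-in for the Huber loss, as the `LossFN` docstring says. A standalone sketch, with the loss branches reimplemented here for illustration (`d` plays the role of `loss_fn_par`):

```python
import math
import torch

def loss(x: torch.Tensor, kind: str, d: float = 1.0) -> torch.Tensor:
    # Mirrors the branches of TrajectoryBalance._loss, for illustration only.
    if kind == "MSE":
        return x * x
    if kind == "MAE":
        return torch.abs(x)
    if kind == "HUB":  # Huber with delta d: quadratic near zero, linear in the tails
        ax = torch.abs(x)
        return torch.where(ax <= d, 0.5 * x * x, d * (ax - 0.5 * d))
    if kind == "GHL":  # log(cosh(d * x))
        return torch.logaddexp(d * x, -d * x) - math.log(2)
    raise NotImplementedError(kind)

x = torch.tensor([0.01, 5.0])
print(loss(x, "GHL"))  # ~[5.0e-05, 4.3069]: quadratic near 0, |x| - log 2 in the tail
print(loss(x, "HUB"))  # [5.0e-05, 4.5]: same shape, but with a kink at |x| = d
```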