MoEs & muP #32

Open · wants to merge 2 commits into base sgd_with_adam
12 changes: 12 additions & 0 deletions README.md
@@ -133,6 +133,18 @@ parser.add_argument('--compile', action='store_true') # if true then model is compiled
parser.add_argument('--rmsnorm_eps', default=1e-5, type=float) # used by the llama model
parser.add_argument('--multiple_of', default=256, type=int) # used by the llama model make SwiGLU hidden layer size multiple of large power of 2
parser.add_argument('--n_kv_head', default=None, type=int) # for Adam-mini
parser.add_argument('--moe', action='store_true')
parser.add_argument('--moe_routing', default='standard_gating', type=str, choices=['standard_gating', 'expert_choice'],)
parser.add_argument('--moe_num_experts', default=8, type=int)
parser.add_argument('--capacity_factor', default=2.0, type=float) # only used for expert choice routing
parser.add_argument('--moe_num_shared_experts', default=0, type=int) # DeepSeek-style shared experts that are always active
parser.add_argument('--moe_router_loss', default='load_balancing_z_loss', type=str, choices=['entropy', 'load_balancing_only', 'load_balancing_z_loss'],)
parser.add_argument('--moe_num_experts_per_tok', default=2, type=int)
parser.add_argument('--moe_entropy_loss_factor', default=0.01, type=float)
parser.add_argument('--moe_aux_loss_factor', default=0.1, type=float)
parser.add_argument('--moe_z_loss_factor', default=0.01, type=float)
parser.add_argument('--moe_softmax_order', type=str, default='topk_softmax', choices=['softmax_topk', 'topk_softmax'],)
parser.add_argument('--plot_router_logits', action='store_true')
# Checkpointing
parser.add_argument('--results_base_folder', default='./exps', type=str)
parser.add_argument('--permanent_ckpt_interval', default=0, type=int)
31 changes: 31 additions & 0 deletions src/config/base.py
@@ -260,5 +260,36 @@ def parse_args(base_parser, args, namespace):
parser.add_argument("--bias", default=False, type=bool)
parser.add_argument("--compile", action="store_true")
parser.add_argument("--mlp_dim_exp_factor", default=1.0, type=float)
parser.add_argument("--moe", action="store_true")
parser.add_argument(
"--moe_routing",
default="standard_gating",
type=str,
choices=["standard_gating", "expert_choice"],
)
parser.add_argument("--moe_num_experts", default=8, type=int)
parser.add_argument( # only used for expert choice routing
"--capacity_factor", default=2.0, type=float
)
parser.add_argument( # DeepSeek-style shared experts that are always active
"--moe_num_shared_experts", default=0, type=int
)
parser.add_argument(
"--moe_router_loss",
default="load_balancing_z_loss",
type=str,
choices=["entropy", "load_balancing_only", "load_balancing_z_loss"],
)
parser.add_argument("--moe_num_experts_per_tok", default=2, type=int)
parser.add_argument("--moe_entropy_loss_factor", default=0.01, type=float)
parser.add_argument("--moe_aux_loss_factor", default=0.1, type=float)
parser.add_argument("--moe_z_loss_factor", default=0.01, type=float)
parser.add_argument(
"--moe_softmax_order",
type=str,
default="topk_softmax",
choices=["softmax_topk", "topk_softmax"],
)
parser.add_argument("--plot_router_logits", action="store_true")

return parser.parse_args(args, namespace)
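For reference (not part of the diff): a minimal, self-contained sketch of how the new MoE flags behave with argparse, using the same names and defaults as the hunk above. The standalone parser below is only an illustration, not the project's actual config entry point.

import argparse

# Hypothetical standalone parser mirroring a subset of the MoE flags added above.
parser = argparse.ArgumentParser()
parser.add_argument("--moe", action="store_true")
parser.add_argument("--moe_routing", default="standard_gating", type=str,
                    choices=["standard_gating", "expert_choice"])
parser.add_argument("--moe_num_experts", default=8, type=int)
parser.add_argument("--moe_num_experts_per_tok", default=2, type=int)
parser.add_argument("--moe_aux_loss_factor", default=0.1, type=float)

# e.g. an expert-choice run with 16 experts
args = parser.parse_args(["--moe", "--moe_routing", "expert_choice",
                          "--moe_num_experts", "16"])
print(args.moe, args.moe_routing, args.moe_num_experts)  # True expert_choice 16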
3 changes: 3 additions & 0 deletions src/main.py
@@ -735,6 +735,7 @@ def get_exp_name(
"device",
"adema_beta3_warmup",
"adema_alpha_warmup",
"plot_router_logits",
],
):
# Get the default values
@@ -747,6 +748,8 @@
for key in key_args:
if hasattr(args, key):
value = getattr(args, key)
if key == "model" and hasattr(args, "moe") and args.moe:
value = f"moe_{value}"
prefix_parts.append(f"{key}-{value}")

prefix = "_".join(prefix_parts)
170 changes: 158 additions & 12 deletions src/models/base.py
@@ -14,6 +14,9 @@
import torch.nn as nn
from torch.nn import functional as F

from models.moe import (ExpertChoiceMoE, MoE, entropy_reg, load_balancing_loss,
router_z_loss)


class LayerNorm(nn.Module):
"""LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""
@@ -113,7 +116,7 @@ def forward(self, x):
x = self.activation(x)
x = self.c_proj(x)
x = self.dropout(x)
return x
return x, {}  # empty aux dict keeps the return signature consistent with the MoE variants


class Block(nn.Module):
@@ -124,20 +127,32 @@ def __init__(self, config):
self.parallel = config.parallel_block
if not self.parallel:
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
self.mlp = MLP(config)
if config.moe:
if config.moe_routing == "standard_gating":
self.mlp = MoE(config, MLP)
elif config.moe_routing == "expert_choice":
self.mlp = ExpertChoiceMoE(config, MLP)
elif config.moe_routing == "soft_moe":
self.mlp = SoftMoE(config, MLP)
elif config.moe_routing == "tree":
self.mlp = TreeRouter(config, MLP)
else:
raise ValueError(f"Unknown routing: {config.moe_routing}")
else:
self.mlp = MLP(config)

def forward(self, x, *args, **kwargs):
if self.parallel:
# from GPT-J 6B https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L299
x_ln = self.ln_1(x, *args, **kwargs)
x_attn = self.attn(x_ln)
x_ffn = self.mlp(x_ln)
x_ffn, logits_and_experts = self.mlp(x_ln)
x = x + x_attn + x_ffn
else:
x = x + self.attn(self.ln_1(x, *args, **kwargs))
x_ = self.mlp(self.ln_2(x, *args, **kwargs))
x_, logits_and_experts = self.mlp(self.ln_2(x, *args, **kwargs))
x = x + x_
return x
return x, logits_and_experts


class GPTBase(nn.Module):
@@ -177,6 +192,37 @@ def __init__(self, config):
mean=0.0,
std=self.config.init_std / math.sqrt(2 * config.n_layer),
)
if pn.endswith("router.weight"):
# scale-preserving router init: normalize the router weights along the routing dimension, then rescale to restore the original std
with torch.no_grad():
dim = 1 if config.moe_routing == "standard_gating" else 0
std = p.std()
p.div_(p.sum(dim=dim, keepdim=True))
p.mul_(std / p.std())

def get_router_losses(self, logits, selected_experts, eval=False):
# logits: (b * seq_len, n_experts)
# selected_experts: (b * seq_len, topk)
if eval: # eval mode, compute all losses
return {
"moe_entropy_loss": entropy_reg(logits),
"moe_aux_loss": load_balancing_loss(logits, selected_experts),
"moe_z_loss": router_z_loss(logits),
}
if self.config.moe_router_loss == "entropy":
return {
"moe_entropy_loss": entropy_reg(logits),
}
elif self.config.moe_router_loss == "load_balancing_only":
return {
"moe_aux_loss": load_balancing_loss(logits, selected_experts),
}
elif self.config.moe_router_loss == "load_balancing_z_loss":
return {
"moe_aux_loss": load_balancing_loss(logits, selected_experts),
"moe_z_loss": router_z_loss(logits),
}
return {}
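Note: entropy_reg, load_balancing_loss, and router_z_loss are imported from src/models/moe.py, which this diff does not touch. As a reference only, common forms of these router losses (Switch Transformer load balancing, ST-MoE z-loss, mean routing entropy) look roughly like the sketch below; the actual implementations in moe.py may differ in details such as scaling or sign conventions.

import torch
import torch.nn.functional as F

def entropy_reg_sketch(logits, eps=1e-9):
    # mean entropy of the per-token routing distribution
    probs = F.softmax(logits, dim=-1)
    return -(probs * torch.log(probs + eps)).sum(dim=-1).mean()

def load_balancing_loss_sketch(logits, selected_experts):
    # Switch-Transformer-style: tokens-per-expert fraction times mean router probability
    n_experts = logits.size(-1)
    probs = F.softmax(logits, dim=-1)                          # (tokens, n_experts)
    one_hot = F.one_hot(selected_experts, n_experts).float()   # (tokens, topk, n_experts)
    tokens_per_expert = one_hot.sum(dim=1).mean(dim=0)         # avg assignments per expert
    prob_per_expert = probs.mean(dim=0)
    return n_experts * (tokens_per_expert * prob_per_expert).sum()

def router_z_loss_sketch(logits):
    # ST-MoE z-loss: penalizes large router logits for numerical stability
    return torch.logsumexp(logits, dim=-1).pow(2).mean()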

def get_num_params(self, non_embedding=True):
"""
@@ -198,7 +244,7 @@ def _init_weights(self, module):
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)

def forward(self, idx, targets=None, get_logits=False):
def forward(self, idx, targets=None, get_logits=False, moe=False):
device = idx.device
b, t = idx.size()
assert (
@@ -214,17 +260,42 @@ def forward(self, idx, targets=None, get_logits=False):
) # position embeddings of shape (1, t, n_embd)
x = self.transformer.drop(tok_emb + pos_emb)

# router logits is a list for each layer's routing, each of shape (b * seq_len, n_experts)
router_logits = []
# experts is a list for each layer's selected experts, shape (b * seq_len, topk)
experts = []

# forward pass through all the transformer blocks
for block in self.transformer.h:
x = block(x)
x, logits_and_experts = block(x)
if len(logits_and_experts) > 0:
router_logits.append(logits_and_experts["router_logits"])
experts.append(logits_and_experts["selected_experts"])
x = self.transformer.ln_f(x)

# aux_losses is a dict with keys for different auxiliary losses
aux_losses = {}

if targets is not None:
# if we are given some desired targets also calculate the loss
logits = self.lm_head(x)
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
)
if moe and self.config.moe_routing == "standard_gating":
# calculate the router losses per layer
for logit, expert_choice in zip(router_logits, experts):
router_losses = self.get_router_losses(
logit, expert_choice, eval=not self.training
)
for k, v in router_losses.items():
aux_losses[k] = aux_losses.get(k, 0.0) + v
if self.training:
loss += (
v
* getattr(self.config, k + "_factor")
/ self.config.n_layer
)

else:
# inference-time mini-optimization: only forward the lm_head on the very last position
@@ -233,9 +304,14 @@
) # note: using list [-1] to preserve the time dim
loss = None
logits = logits if get_logits else None
router_logits = (
torch.stack(router_logits, dim=0) if len(router_logits) > 0 else None
)
return {
"logits": logits,
"loss": loss,
"aux_losses": aux_losses,
"router_logits": router_logits,
}
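In words, when --moe is on with standard gating, the objective becomes the cross-entropy loss plus each enabled router loss, scaled by its --moe_*_factor and averaged over layers. A tiny numeric sketch with made-up values (not from any real run):

# hypothetical per-layer router losses for a 2-layer model trained with
# --moe_router_loss load_balancing_z_loss
n_layer = 2
factors = {"moe_aux_loss": 0.1, "moe_z_loss": 0.01}
per_layer = [
    {"moe_aux_loss": 1.05, "moe_z_loss": 4.0},
    {"moe_aux_loss": 1.10, "moe_z_loss": 3.5},
]

loss = 3.20  # hypothetical cross-entropy loss
for layer_losses in per_layer:
    for k, v in layer_losses.items():
        loss += v * factors[k] / n_layer

print(round(loss, 4))  # 3.2 + 2.15*0.1/2 + 7.5*0.01/2 = 3.345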

def crop_sequence_length(self, sequence_length):
@@ -250,9 +326,60 @@ def crop_sequence_length(self, sequence_length):
for block in self.transformer.h:
block.attn.bias = block.attn.bias[:, :, :sequence_length, :sequence_length]

def convert_dense_to_sparse(self, state_dict):
"""
Convert a dense model state dict into the sparse (MoE) layout by copying each MLP into every expert.
"""
state_to_load = {}
for k, v in state_dict.items():
vals = k.split(".")
print(vals)
if len(vals) >= 5 and vals[4] == "mlp":
# for layer i, go from '_orig_mod.transformer.h.i.mlp.c_fc.weight' to
# '_orig_mod.transformer.h.i.mlp.experts.e.c_fc.weight'
for e in range(self.config.moe_num_experts):
state_to_load[
".".join(vals[1:5] + ["experts", str(e)] + vals[5:])
] = v
# add router weight from already initialized weights above
state_to_load[".".join(vals[1:5] + ["router", "weight"])] = (
self.transformer.h[int(vals[3])].mlp.router.weight
)
else:
state_to_load[".".join(k.split(".")[1:])] = v
return state_to_load
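A sketch of the key remapping this method performs (layer and expert indices are just examples): each dense MLP weight is duplicated into every expert slot, while the router weight is taken from the freshly initialized MoE model rather than from the dense checkpoint.

k = "_orig_mod.transformer.h.0.mlp.c_fc.weight"
vals = k.split(".")
# vals[1:5] == ['transformer', 'h', '0', 'mlp'], vals[5:] == ['c_fc', 'weight']
for e in range(2):  # pretend moe_num_experts == 2
    print(".".join(vals[1:5] + ["experts", str(e)] + vals[5:]))
# transformer.h.0.mlp.experts.0.c_fc.weight
# transformer.h.0.mlp.experts.1.c_fc.weight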

def convert_n_dense_to_sparse(self, state_dicts):
"""
Convert one dense state dict per expert into the sparse (MoE) layout, mapping the e-th checkpoint's MLP weights onto expert e.
"""
assert (
len(state_dicts) == self.config.moe_num_experts
), f"len(state_dict)={len(state_dicts)} != {self.config.moe_num_experts}."
state_to_load = {}
for e in range(self.config.moe_num_experts):
state_dict = state_dicts[e]
for k, v in state_dict.items():
vals = k.split(".")
print(vals)
if len(vals) >= 5 and vals[4] == "mlp":
# for layer i, go from '_orig_mod.transformer.h.i.mlp.c_fc.weight' to
# '_orig_mod.transformer.h.i.mlp.experts.e.c_fc.weight'
state_to_load[
".".join(vals[1:5] + ["experts", str(e)] + vals[5:])
] = v
# add router weight from already initialized weights above
state_to_load[".".join(vals[1:5] + ["router", "weight"])] = (
self.transformer.h[int(vals[3])].mlp.router.weight
)
else:
state_to_load[".".join(k.split(".")[1:])] = v
return state_to_load

def from_pretrained(
self,
model_path,
from_dense: bool = True,
):
paths = model_path.split(",")
if len(paths) == 1:
@@ -263,11 +390,30 @@
)
state_to_load = loaded_state["model"]

# load the sparse model
state_to_load = {
".".join(k.split(".")[1:]): v # drop _orig_mod from keys
for k, v in state_to_load.items()
}
if self.config.moe and from_dense:
# load the dense model and convert to sparse
state_to_load = self.convert_dense_to_sparse(state_to_load)
else:
# load the sparse model
state_to_load = {
".".join(k.split(".")[1:]): v # drop _orig_mod from keys
for k, v in state_to_load.items()
}
else:
loaded_states = []
for path in paths:
loaded_state = torch.load(
str(path + "/ckpt.pt"),
map_location=torch.device(self.config.device),
)
loaded_states.append(loaded_state["model"])
if self.config.moe and from_dense:
# load the dense model and convert to sparse
print(f"Loading from {len(paths)} dense models.")
state_to_load = self.convert_n_dense_to_sparse(loaded_states)
else:
raise NotImplementedError("Multiple paths -> load from dense.")
super().load_state_dict(state_to_load)
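A usage sketch for the two dense-to-sparse loading paths (the model object, expert count, and checkpoint paths are hypothetical; construction of the MoE GPTBase is omitted):

# `model` is assumed to be a GPTBase built with config.moe=True and
# config.moe_num_experts=8; the paths below are made up.

# one dense checkpoint: every expert starts from the same MLP weights
model.from_pretrained("./exps/dense_run", from_dense=True)

# one dense checkpoint per expert, passed as comma-separated paths
model.from_pretrained(",".join(f"./exps/dense_expert_{i}" for i in range(8)),
                      from_dense=True)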

def get_parameter_group_specs(self):
"""