
Added OLMo(E) v1 #816

Open
wants to merge 20 commits into base: dev
1 change: 1 addition & 0 deletions .gitignore
@@ -20,3 +20,4 @@ docs/build
docs/source/generated
**.orig
.venv

49 changes: 32 additions & 17 deletions transformer_lens/HookedTransformer.py
@@ -143,7 +143,6 @@ def __init__(
)

self.cfg = HookedTransformerConfig.unwrap(cfg)

if tokenizer is not None:
self.set_tokenizer(tokenizer, default_padding_side=default_padding_side)
elif self.cfg.tokenizer_name is not None:
@@ -161,13 +160,18 @@ def __init__(
if "phi" in self.cfg.tokenizer_name.lower():
use_fast = False
huggingface_token = os.environ.get("HF_TOKEN", "")
# OLMo-family tokenizers are not meant to have a BOS token prepended, so don't force one.
add_bos_token = self.cfg.original_architecture not in [
"OlmoForCausalLM",
"OlmoeForCausalLM",
"Olmo2ForCausalLM",
]
self.set_tokenizer(
AutoTokenizer.from_pretrained(
self.cfg.tokenizer_name,
add_bos_token=True,
trust_remote_code=self.cfg.trust_remote_code,
use_fast=use_fast,
token=huggingface_token if len(huggingface_token) > 0 else None,
add_bos_token=add_bos_token,
),
default_padding_side=default_padding_side,
)
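Since the loader above now passes add_bos_token=False for the OLMo family, a quick way to confirm what a given checkpoint's tokenizer does on its own is a one-off check like the sketch below (assumes access to the Hugging Face Hub; the model name is just one entry from the list this PR registers).

    from transformers import AutoTokenizer

    # Inspect whether the tokenizer prepends a BOS id by itself; the loader above no
    # longer forces add_bos_token=True for OLMo-family checkpoints.
    tok = AutoTokenizer.from_pretrained("allenai/OLMo-1B-hf")
    ids = tok("hello")["input_ids"]
    print(tok.bos_token, ids[:3])
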
@@ -734,7 +738,14 @@ def set_tokenizer(
# tokenizers like LlamaTokenizer are different when bos token is automatically/manually
# prepended, and add_bos_token cannot be dynamically controlled after initialization
# (https://github.com/huggingface/transformers/issues/25886).
tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer)
# For OLMo-family models, keep the tokenizer as-is rather than swapping in a BOS-adding variant.
if self.cfg.original_architecture not in [
"OlmoForCausalLM",
"OlmoeForCausalLM",
"Olmo2ForCausalLM",
]:
tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer)
else:
tokenizer_with_bos = tokenizer
self.tokenizer = tokenizer_with_bos
self.tokenizer.padding_side = default_padding_side

@@ -1798,18 +1809,18 @@ def fold_layer_norm(
if not self.cfg.final_rms and fold_biases:
# Dumb bug from my old SoLU training code, some models have RMSNorm instead of LayerNorm
# pre unembed.
state_dict[f"unembed.b_U"] = state_dict[f"unembed.b_U"] + (
state_dict[f"unembed.W_U"] * state_dict[f"ln_final.b"][:, None]
state_dict["unembed.b_U"] = state_dict["unembed.b_U"] + (
state_dict["unembed.W_U"] * state_dict["ln_final.b"][:, None]
).sum(dim=-2)
del state_dict[f"ln_final.b"]
del state_dict["ln_final.b"]

state_dict[f"unembed.W_U"] = state_dict[f"unembed.W_U"] * state_dict[f"ln_final.w"][:, None]
del state_dict[f"ln_final.w"]
state_dict["unembed.W_U"] = state_dict["unembed.W_U"] * state_dict["ln_final.w"][:, None]
del state_dict["ln_final.w"]

if center_weights:
# Center the weights that read in from the LayerNormPre
state_dict[f"unembed.W_U"] -= einops.reduce(
state_dict[f"unembed.W_U"], "d_model d_vocab -> 1 d_vocab", "mean"
state_dict["unembed.W_U"] -= einops.reduce(
state_dict["unembed.W_U"], "d_model d_vocab -> 1 d_vocab", "mean"
)

return state_dict
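The algebra behind this hunk (unchanged by the PR, which only drops the unnecessary f-strings): with LN(x) = x_hat * w + b, the logits LN(x) @ W_U + b_U equal x_hat @ (W_U * w[:, None]) + (b_U + (W_U * b[:, None]).sum(-2)), so the final norm's scale and bias can be folded into the unembed. A small numerical check of that identity, with toy shapes standing in for the real config:

    import torch

    # Verify the fold: LN(x) @ W_U + b_U == x_hat @ (W_U * w[:, None]) + folded bias.
    d_model, d_vocab = 8, 11
    x_hat = torch.randn(3, d_model)          # stand-in for the LayerNormPre output
    w, b = torch.randn(d_model), torch.randn(d_model)
    W_U, b_U = torch.randn(d_model, d_vocab), torch.randn(d_vocab)

    unfolded = (x_hat * w + b) @ W_U + b_U
    folded = x_hat @ (W_U * w[:, None]) + (b_U + (W_U * b[:, None]).sum(dim=-2))
    assert torch.allclose(unfolded, folded, atol=1e-5)
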
@@ -1821,13 +1832,17 @@ def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]):
W_out. This is done by subtracting the mean of the weights from the weights themselves. This
is done in-place. See fold_layer_norm for more details.
"""
state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict["embed.W_E"].mean(
-1, keepdim=True
)
if self.cfg.positional_embedding_type != "rotary":
state_dict["pos_embed.W_pos"] = state_dict["pos_embed.W_pos"] - state_dict[
"pos_embed.W_pos"
].mean(-1, keepdim=True)
if self.cfg.original_architecture == "Olmo2ForCausalLM":
print("Not centering embedding weights for Olmo2ForCausalLM")
pass # should not because input of attn of 1st layer is not normed
else:
state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict["embed.W_E"].mean(
-1, keepdim=True
)
if self.cfg.positional_embedding_type != "rotary":
state_dict["pos_embed.W_pos"] = state_dict["pos_embed.W_pos"] - state_dict[
"pos_embed.W_pos"
].mean(-1, keepdim=True)
for l in range(self.cfg.n_layers):
state_dict[f"blocks.{l}.attn.W_O"] = state_dict[f"blocks.{l}.attn.W_O"] - state_dict[
f"blocks.{l}.attn.W_O"
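The rationale for the new branch above: centering writing weights is safe whenever the next reader of the residual stream is a pre-norm, because mean-subtraction removes any constant offset; OLMo 2 normalizes after attention, so the first block reads the raw embedding and the offset would matter. A toy illustration of the mean-subtraction argument (shapes are arbitrary):

    import torch

    # Mean-subtraction (the first step of LayerNorm / LayerNormPre) kills constant offsets,
    # which is why centered and uncentered writing weights are equivalent under pre-norm.
    d_model = 16
    x = torch.randn(4, d_model)

    def center(t):
        return t - t.mean(-1, keepdim=True)

    assert torch.allclose(center(x), center(x + 3.0), atol=1e-5)
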
4 changes: 2 additions & 2 deletions transformer_lens/HookedTransformerConfig.py
@@ -192,8 +192,7 @@ class HookedTransformerConfig:
NTK_by_parts_factor (float): The overall factor used in the "NTK-by-parts" method that
affects the rate of change between low and high-frequency interpolation strategies.
Defaults to 8.0.


norm_topk_prob (bool): Whether to renormalize the top-k expert routing weights so they sum to 1 in the MoE layer. Defaults to False.
"""

n_layers: int
@@ -264,6 +263,7 @@ class HookedTransformerConfig:
NTK_by_parts_high_freq_factor: float = 4.0
NTK_by_parts_factor: float = 8.0
NTK_original_ctx_len: int = 8192
norm_topk_prob: bool = False

def __post_init__(self):
if self.n_heads == -1:
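With the new field in place, an MoE config can opt in or out of renormalizing the selected routing weights. A minimal construction sketch with hypothetical sizes (chosen only for illustration, not taken from any checkpoint):

    from transformer_lens import HookedTransformerConfig

    # norm_topk_prob=False keeps the raw softmax mass of the chosen experts (the new
    # default); True renormalizes the top-k weights to sum to 1.
    cfg = HookedTransformerConfig(
        n_layers=2,
        d_model=64,
        d_head=16,
        n_ctx=128,
        d_mlp=256,
        act_fn="silu",
        num_experts=8,
        experts_per_token=2,
        norm_topk_prob=False,
    )
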
48 changes: 44 additions & 4 deletions transformer_lens/components/abstract_attention.py
@@ -154,6 +154,18 @@ def __init__(
# will be overwritten by the child T5Attention class
self.has_relative_attention_bias = False

# QK-norm parameters for OLMoE / OLMo 2, applied to the flattened queries and keys in forward.
if (
self.cfg.original_architecture == "OlmoeForCausalLM"
or self.cfg.original_architecture == "Olmo2ForCausalLM"
):
self.q_norm = RMSNorm(self.cfg, self.cfg.d_model)
k_norm_dim = (
self.cfg.d_model
if self.cfg.original_architecture == "Olmo2ForCausalLM"
else self.cfg.d_head * self.cfg.n_key_value_heads
)
self.k_norm = RMSNorm(self.cfg, k_norm_dim)

@property
def OV(self) -> FactoredMatrix:
"""
@@ -209,6 +221,32 @@ def forward(

q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)

# OLMoE and OLMo 2 apply QK-norm: RMSNorm over the flattened queries and keys.
if (
self.cfg.original_architecture == "OlmoeForCausalLM"
or self.cfg.original_architecture == "Olmo2ForCausalLM"
):
q = einops.rearrange(
self.q_norm(
einops.rearrange(
q,
"batch pos head_index d_head -> batch pos (head_index d_head)",
)
),
"batch kv_pos (head_index d_head) -> batch kv_pos head_index d_head",
head_index=q.shape[2],
)
k = einops.rearrange(
self.k_norm(
einops.rearrange(
k,
"batch pos head_index d_head -> batch pos (head_index d_head)",
)
),
"batch kv_pos (head_index d_head) -> batch kv_pos head_index d_head",
head_index=k.shape[2],
)

if past_kv_cache_entry is not None:
# Appends the new keys and values to the cached values, and automatically updates the cache
kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)
@@ -244,9 +282,10 @@ def forward(
)

# Take the last query_ctx positions so it also works with past_kv_cache
attn_scores += self.alibi[
:, -query_ctx:, :key_ctx
] # [batch, head_index, query_pos, key_pos]
if self.alibi is not None:  # self.alibi may still be None here; only add the ALiBi bias once it exists
attn_scores += self.alibi[
:, -query_ctx:, :key_ctx
] # [batch, head_index, query_pos, key_pos]
elif self.cfg.positional_embedding_type == "relative_positional_bias":
if position_bias is None:
if self.has_relative_attention_bias:
@@ -260,7 +299,8 @@
device=attn_scores.device,
)

attn_scores += position_bias
if position_bias is not None:  # position_bias can still be None (no relative attention bias); only add it when present
attn_scores += position_bias
if self.cfg.attention_dir == "causal":
# If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask.
attn_scores = self.apply_causal_mask(
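What the QK-norm path added above computes, in isolation: flatten the head dimension, apply RMSNorm over it, and reshape back. A self-contained sketch with made-up shapes and a unit scale (the real code uses the component's RMSNorm and the model's learned weights):

    import torch
    import einops

    def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # RMSNorm over the last dimension: x / sqrt(mean(x**2) + eps), times a learned scale.
        return x * (x.pow(2).mean(-1, keepdim=True) + eps).rsqrt() * weight

    batch, pos, n_heads, d_head = 2, 5, 4, 8
    q = torch.randn(batch, pos, n_heads, d_head)
    w_q = torch.ones(n_heads * d_head)  # stand-in for the learned q_norm weight

    q_flat = einops.rearrange(q, "batch pos head d_head -> batch pos (head d_head)")
    q_normed = einops.rearrange(
        rms_norm(q_flat, w_q),
        "batch pos (head d_head) -> batch pos head d_head",
        head=n_heads,
    )
    print(q_normed.shape)  # torch.Size([2, 5, 4, 8])
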
3 changes: 2 additions & 1 deletion transformer_lens/components/mlps/moe.py
@@ -88,7 +88,8 @@ def forward(
# both are [batch, pos, experts_per_token]
weights = self.hook_expert_weights(F.softmax(gate_logits, dim=1, dtype=torch.float))
weights, expert_indices = torch.topk(weights, self.experts_per_token, dim=-1)
weights /= weights.sum(dim=-1, keepdim=True)
if self.cfg.norm_topk_prob:
weights /= weights.sum(dim=-1, keepdim=True)
expert_indices = self.hook_expert_indices(expert_indices)
weights = weights.to(x.dtype)

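The effect of the guard above on routing, sketched with explicit shapes (illustrative only; the real value of norm_topk_prob comes from the checkpoint's Hugging Face config):

    import torch
    import torch.nn.functional as F

    gate_logits = torch.randn(2, 5, 8)                    # [batch, pos, n_experts]
    weights = F.softmax(gate_logits, dim=-1, dtype=torch.float)
    topk_w, topk_idx = torch.topk(weights, k=2, dim=-1)   # [batch, pos, experts_per_token]

    norm_topk_prob = False  # illustrative; set from cfg.norm_topk_prob in the real code
    if norm_topk_prob:
        topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)  # selected weights now sum to 1
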
41 changes: 28 additions & 13 deletions transformer_lens/components/transformer_block.py
@@ -153,26 +153,37 @@ def forward(
key_input = attn_in
value_input = attn_in

attn_out = (
# hook the residual stream states that are used to calculate the
# queries, keys and values, independently.
# Then take the layer norm of these inputs, and pass these to the attention module.
self.attn(
query_input=self.ln1(query_input)
+ (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
key_input=self.ln1(key_input)
+ (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
value_input=self.ln1(value_input),
if self.cfg.original_architecture == "Olmo2ForCausalLM":
attn_out = self.attn(
query_input=query_input,
key_input=key_input,
value_input=value_input,
past_kv_cache_entry=past_kv_cache_entry,
attention_mask=attention_mask,
)
) # [batch, pos, d_model]
else:
attn_out = (
# hook the residual stream states that are used to calculate the
# queries, keys and values, independently.
# Then take the layer norm of these inputs, and pass these to the attention module.
self.attn(
query_input=self.ln1(query_input)
+ (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
key_input=self.ln1(key_input)
+ (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
value_input=self.ln1(value_input),
past_kv_cache_entry=past_kv_cache_entry,
attention_mask=attention_mask,
)
) # [batch, pos, d_model]
if self.cfg.use_normalization_before_and_after:
# If we use LayerNorm both before and after, then apply the second LN after the layer
# and before the hook. We do it before the hook so hook_attn_out captures "that which
# is added to the residual stream"
attn_out = self.ln1_post(attn_out)
attn_out = self.hook_attn_out(attn_out)
if self.cfg.original_architecture == "Olmo2ForCausalLM":
attn_out = self.ln1(attn_out)

if resid_pre.device != attn_out.device:
resid_pre = resid_pre.to(attn_out.device)
Expand All @@ -182,8 +193,12 @@ def forward(
mlp_in = (
resid_mid if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_mid.clone())
)
normalized_resid_mid = self.ln2(mlp_in)
mlp_out = self.apply_mlp(normalized_resid_mid)
if self.cfg.original_architecture == "Olmo2ForCausalLM":
mlp_out = self.apply_mlp(mlp_in)
mlp_out = self.ln2(mlp_out)
else:
normalized_resid_mid = self.ln2(mlp_in)
mlp_out = self.apply_mlp(normalized_resid_mid)
resid_post = self.hook_resid_post(resid_mid + mlp_out) # [batch, pos, d_model]
elif self.cfg.parallel_attn_mlp:
# Dumb thing done by GPT-J, both MLP and Attn read from resid_pre and write to resid_post, no resid_mid used.
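The two wirings the new branch distinguishes, reduced to a pure-Python sketch (attn, mlp, norm1 and norm2 stand in for the real modules; hooks and caching are omitted):

    def pre_norm_block(x, attn, mlp, norm1, norm2):
        x = x + attn(norm1(x))   # normalize each sublayer's input (GPT-2, Llama, OLMoE, ...)
        x = x + mlp(norm2(x))
        return x

    def post_norm_block(x, attn, mlp, norm1, norm2):
        x = x + norm1(attn(x))   # OLMo 2: normalize each sublayer's output instead
        x = x + norm2(mlp(x))
        return x
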
101 changes: 101 additions & 0 deletions transformer_lens/loading_from_pretrained.py
@@ -37,6 +37,9 @@
convert_neel_solu_old_weights,
convert_neo_weights,
convert_neox_weights,
convert_olmo2_weights,
convert_olmo_weights,
convert_olmoe_weights,
convert_opt_weights,
convert_phi3_weights,
convert_phi_weights,
@@ -263,6 +266,20 @@
"google-t5/t5-base",
"google-t5/t5-large",
"ai-forever/mGPT",
"allenai/OLMo-1B-hf",
"allenai/OLMo-7B-hf",
"allenai/OLMo-7B-0724-hf",
"allenai/OLMo-7B-0724-SFT-hf",
"allenai/OLMo-7B-0724-Instruct-hf",
"allenai/OLMo-7B-0424-hf",
"allenai/OLMo-7B-Twin-2T-hf",
"allenai/OLMo-1B-0724-hf",
"allenai/OLMo-7B-Instruct-hf",
"allenai/OLMo-7B-SFT-hf",
"allenai/OLMoE-1B-7B-0924",
"allenai/OLMoE-1B-7B-0924-SFT",
"allenai/OLMoE-1B-7B-0924-Instruct",
"allenai/OLMo-2-1124-7B",
]
"""Official model names for models on HuggingFace."""

@@ -1563,6 +1580,84 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
"final_rms": True,
"use_normalization_before_and_after": True,
}
elif official_model_name.startswith("allenai/OLMo-1B") and official_model_name.endswith("hf"):
cfg_dict = {
"d_model": 2048,
"d_head": 128,
"n_heads": 16,
"d_mlp": 8192,
"n_layers": 16,
"n_ctx": 2048,
"eps": 1e-05,
"d_vocab": 50304,
"act_fn": "silu",
"initializer_range": 0.02,
"normalization_type": "LN",
"rotary_base": 10000.0,
"attn_types": ["global"] * 16,
"positional_embedding_type": "rotary",
"gated_mlp": True,
}
elif official_model_name.startswith("allenai/OLMo-7B") and official_model_name.endswith("hf"):
cfg_dict = {
"d_model": 4096,
"d_head": 128,
"n_heads": 32,
"d_mlp": 11008,
"n_layers": 32,
"n_ctx": 2048,
"eps": 1e-05,
"d_vocab": 50304,
"act_fn": "silu",
"initializer_range": 0.02,
"normalization_type": "LN",
"rotary_base": 10000.0,
"attn_types": ["global"] * 32,
"positional_embedding_type": "rotary",
"gated_mlp": True,
}
elif official_model_name == "allenai/OLMo-2-1124-7B":
cfg_dict = {
"d_model": 4096,
"d_head": 128,
"n_heads": 32,
"d_mlp": 11008,
"n_layers": 32,
"n_ctx": 4096,
"eps": 1e-06,
"d_vocab": 100352,
"act_fn": "silu",
"initializer_range": 0.02,
"normalization_type": "RMSPre",
"rotary_base": 500000.0,
"attn_types": ["global"] * 32,
"positional_embedding_type": "rotary",
"gated_mlp": True,
}
elif architecture == "OlmoeForCausalLM":
cfg_dict = {
"d_model": hf_config.hidden_size,
"d_head": hf_config.hidden_size // hf_config.num_attention_heads,
"n_heads": hf_config.num_attention_heads,
"d_mlp": hf_config.intermediate_size,
"n_layers": hf_config.num_hidden_layers,
"n_ctx": hf_config.max_position_embeddings,
"eps": hf_config.rms_norm_eps,
"d_vocab": hf_config.vocab_size,
"act_fn": hf_config.hidden_act,
"num_experts": hf_config.num_experts,
"experts_per_token": hf_config.num_experts_per_tok,
"norm_topk_prob": hf_config.norm_topk_prob,
"n_key_value_heads": hf_config.num_key_value_heads,
"rotary_base": hf_config.rope_theta,
"tie_word_embeddings": hf_config.tie_word_embeddings,
"initializer_range": hf_config.initializer_range,
"positional_embedding_type": "rotary",
"rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
"final_rms": True,
"gated_mlp": True,
"normalization_type": "LN",
}
elif architecture == "T5ForConditionalGeneration":
cfg_dict = {
"d_model": hf_config.d_model,
@@ -1986,6 +2081,12 @@ def get_pretrained_state_dict(
state_dict = convert_gemma_weights(hf_model, cfg)
elif cfg.original_architecture == "Gemma2ForCausalLM":
state_dict = convert_gemma_weights(hf_model, cfg)
elif cfg.original_architecture == "OlmoForCausalLM":
state_dict = convert_olmo_weights(hf_model, cfg)
elif cfg.original_architecture == "Olmo2ForCausalLM":
state_dict = convert_olmo2_weights(hf_model, cfg)
elif cfg.original_architecture == "OlmoeForCausalLM":
state_dict = convert_olmoe_weights(hf_model, cfg)
else:
raise ValueError(
f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
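With the configs and weight converters above in place, loading one of the newly registered checkpoints should look like any other supported model. A usage sketch, assuming this branch is installed and the weights are reachable on the Hugging Face Hub:

    from transformer_lens import HookedTransformer

    model = HookedTransformer.from_pretrained("allenai/OLMo-1B-hf")
    logits, cache = model.run_with_cache("OLMo is an open language model.")
    print(logits.shape)  # [batch, pos, d_vocab]
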