Model deepseek #949

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 8 commits into dev-3.x
14 changes: 7 additions & 7 deletions transformer_lens/boot.py
@@ -13,27 +13,27 @@

def boot(
model_name: str,
config: dict | None = None,
model_config: dict | None = None,
tokenizer_config: dict | None = None,
device: str | torch.device | None = None,
dtype: torch.dtype = torch.float32,
**kwargs,
) -> TransformerBridge:
"""Boot a model from HuggingFace.

Args:
model_name: The name of the model to load.
config: The config dict to use. If None, will be loaded from HuggingFace.
model_config: Additional configuration parameters to override the default config.
tokenizer_config: The config dict to use for tokenizer loading. If None, will use default settings.
device: The device to use. If None, will be determined automatically.
dtype: The dtype to use for the model.
**kwargs: Additional keyword arguments for from_pretrained.

Returns:
The bridge to the loaded model.
"""
hf_config = AutoConfig.from_pretrained(model_name, **kwargs)
hf_config = AutoConfig.from_pretrained(model_name, **(model_config or {}))
adapter = ArchitectureAdapterFactory.select_architecture_adapter(hf_config)
default_config = adapter.default_cfg
merged_config = {**default_config, **(config or {})}
merged_config = {**default_config, **(model_config or {})}

# Load the model from HuggingFace using the original config
hf_model = AutoModelForCausalLM.from_pretrained(
@@ -44,7 +44,7 @@ def boot(
)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_name, **(tokenizer_config or {}))

return TransformerBridge(
hf_model,
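A usage sketch for the updated signature, showing how the new model_config / tokenizer_config split is exercised; the checkpoint name and the specific override keys are illustrative assumptions, not part of this PR.

import torch

from transformer_lens.boot import boot

bridge = boot(
    "bigscience/bloom-560m",  # illustrative; BloomForCausalLM is among the supported architectures
    model_config={"output_attentions": True},  # forwarded to AutoConfig.from_pretrained and merged over the adapter's default config
    tokenizer_config={"padding_side": "left"},  # forwarded to AutoTokenizer.from_pretrained
    device="cpu",
    dtype=torch.float32,
)

Note that model_config now covers both roles that config and **kwargs used to split: it is passed to AutoConfig.from_pretrained and merged over the adapter's default config, while **kwargs remains documented as extra arguments for from_pretrained.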
2 changes: 2 additions & 0 deletions transformer_lens/factories/architecture_adapter_factory.py
@@ -9,6 +9,7 @@
from transformer_lens.model_bridge.supported_architectures import (
BertArchitectureAdapter,
BloomArchitectureAdapter,
DeepseekArchitectureAdapter,
Gemma1ArchitectureAdapter,
Gemma2ArchitectureAdapter,
Gemma3ArchitectureAdapter,
@@ -35,6 +36,7 @@
SUPPORTED_ARCHITECTURES = {
"BertForMaskedLM": BertArchitectureAdapter,
"BloomForCausalLM": BloomArchitectureAdapter,
"DeepseekV3ForCausalLM": DeepseekArchitectureAdapter,
"GemmaForCausalLM": Gemma1ArchitectureAdapter, # Default to Gemma1 as it's the original version
"Gemma1ForCausalLM": Gemma1ArchitectureAdapter,
"Gemma2ForCausalLM": Gemma2ArchitectureAdapter,
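For reference, a minimal sketch of how the new registry entry would be resolved. How select_architecture_adapter matches the config against SUPPORTED_ARCHITECTURES (presumably via hf_config.architectures) is not shown in this diff, and the checkpoint name, import path, and trust_remote_code flag are assumptions.

from transformers import AutoConfig

from transformer_lens.factories.architecture_adapter_factory import (
    ArchitectureAdapterFactory,
)

# Illustrative checkpoint whose config reports "DeepseekV3ForCausalLM".
hf_config = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-V3", trust_remote_code=True)

# Expected to resolve to DeepseekArchitectureAdapter through the new dict entry.
adapter = ArchitectureAdapterFactory.select_architecture_adapter(hf_config)
print(adapter)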
32 changes: 7 additions & 25 deletions transformer_lens/model_bridge/generalized_components/attention.py
@@ -38,8 +38,8 @@ def __init__(
# Add all the hooks from the old attention components
self.hook_k = HookPoint() # [batch, pos, head_index, d_head]
self.hook_q = HookPoint() # [batch, pos, head_index, d_head]
self.hook_v = HookPoint() # [batch, pos, head_index, d_head]
self.hook_z = HookPoint() # [batch, pos, head_index, d_head]
self.hook_v = HookPoint() # Value vectors
self.hook_z = HookPoint() # Attention output
self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos]
self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos]
self.hook_result = HookPoint() # [batch, pos, head_index, d_model]
@@ -50,31 +50,13 @@ def __init__(
self.hook_rot_q = HookPoint() # [batch, pos, head_index, d_head] (for rotary)

def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Forward pass through the attention layer.

This method forwards all arguments to the original component and applies hooks
to the output. The arguments should match the original component's forward method.
"""Forward pass through the attention bridge.

Args:
*args: Input arguments to pass to the original component
**kwargs: Input keyword arguments to pass to the original component
*args: Positional arguments for the original component
**kwargs: Keyword arguments for the original component

Returns:
The output from the original component, with hooks applied
Output from the original component
"""
# Handle hook_attn_input for shortformer positional embeddings
if "query_input" in kwargs:
# Combine normalized residual stream with positional embeddings
attn_input = kwargs["query_input"]
# Pass through hook_attn_input
attn_input = self.hook_attn_input(attn_input)
# Update query_input with the hooked value
kwargs["query_input"] = attn_input

# Forward through the original component
output = self.original_component(*args, **kwargs)

# Execute hooks on the output (for add_hook compatibility)
output = self.execute_hooks("output", output)

return output
return self.original_component(*args, **kwargs)
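Since the simplified forward now just delegates to the wrapped HF module, hook behaviour rests on the HookPoint attributes declared in __init__ above. A standalone sketch of the standard HookPoint usage those attributes assume; the import and add_hook API come from TransformerLens itself rather than from this diff, and whether each hook point fires during the bridge's forward depends on wiring not shown here.

import torch

from transformer_lens.hook_points import HookPoint

hook_pattern = HookPoint()
hook_pattern.name = "blocks.0.attn.hook_pattern"  # names are normally assigned by the parent model

def show_pattern_shape(pattern, hook):
    # pattern: [batch, head_index, query_pos, key_pos], per the comment above
    print(hook.name, tuple(pattern.shape))
    return pattern

hook_pattern.add_hook(show_pattern_shape)

# HookPoint is an identity module; the added hook fires when a tensor passes through it.
_ = hook_pattern(torch.zeros(1, 2, 3, 3))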
18 changes: 0 additions & 18 deletions transformer_lens/model_bridge/generalized_components/block.py
@@ -40,24 +40,6 @@ def __init__(
"""
super().__init__(original_component, name, architecture_adapter)

@classmethod
def wrap_component(
cls, component: nn.Module, name: str, architecture_adapter: ArchitectureAdapter
) -> nn.Module:
"""Wrap a component with this bridge if it's a transformer block.

Args:
component: The component to wrap
name: The name of the component
architecture_adapter: The architecture adapter instance

Returns:
The wrapped component if it's a transformer block, otherwise the original component
"""
if name.endswith(".block") or name.endswith(".layer"):
return cls(component, name, architecture_adapter)
return component

def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Forward pass through the block bridge.

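The same suffix-based wrap_component classmethod is removed from the embedding, layer-norm, MLP, MoE, and unembedding bridges in the files that follow; wrapping is instead declared per architecture. A minimal sketch of that declarative style, mirroring the component_mapping in the new DeepSeek adapter at the end of this diff (the abbreviated dict below is an illustration, not code from this PR):

from transformer_lens.model_bridge.generalized_components import (
    BlockBridge,
    EmbeddingBridge,
    LayerNormBridge,
)

# Rather than each bridge guessing from name suffixes (".block", ".embed", ".ln1", ...),
# the adapter states explicitly which HF submodule maps to which bridge class.
component_mapping = {
    "embed": ("model.embed_tokens", EmbeddingBridge),
    "ln_final": ("model.norm", LayerNormBridge),
    "blocks": ("model.layers", BlockBridge, {}),  # per-block sub-mapping elided here
}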
18 changes: 0 additions & 18 deletions transformer_lens/model_bridge/generalized_components/embedding.py
@@ -83,21 +83,3 @@ def forward(
self.hook_outputs.update({"output": output})

return output

@classmethod
def wrap_component(
cls, component: nn.Module, name: str, architecture_adapter: ArchitectureAdapter
) -> nn.Module:
"""Wrap a component with this bridge if it's an embedding layer.

Args:
component: The component to wrap
name: The name of the component
architecture_adapter: The architecture adapter instance

Returns:
The wrapped component if it's an embedding layer, otherwise the original component
"""
if name.endswith(".embed") or name.endswith(".embed_tokens"):
return cls(component, name, architecture_adapter)
return component
18 changes: 0 additions & 18 deletions transformer_lens/model_bridge/generalized_components/layer_norm.py
@@ -68,21 +68,3 @@ def forward(
self.hook_outputs.update({"output": output})

return output

@classmethod
def wrap_component(
cls, component: nn.Module, name: str, architecture_adapter: ArchitectureAdapter
) -> nn.Module:
"""Wrap a component with this bridge if it's a LayerNorm layer.

Args:
component: The component to wrap
name: The name of the component
architecture_adapter: The architecture adapter instance

Returns:
The wrapped component if it's a LayerNorm layer, otherwise the original component
"""
if name.endswith(".ln") or name.endswith(".ln1") or name.endswith(".ln2"):
return cls(component, name, architecture_adapter)
return component
18 changes: 0 additions & 18 deletions transformer_lens/model_bridge/generalized_components/mlp.py
@@ -73,21 +73,3 @@ def forward(self, *args, **kwargs) -> torch.Tensor:
self.hook_outputs.update({"output": output})

return output

@classmethod
def wrap_component(
cls, component: nn.Module, name: str, architecture_adapter: ArchitectureAdapter
) -> nn.Module:
"""Wrap a component with this bridge if it's an MLP layer.

Args:
component: The component to wrap
name: The name of the component
architecture_adapter: The architecture adapter instance

Returns:
The wrapped component if it's an MLP layer, otherwise the original component
"""
if name.endswith(".mlp"):
return cls(component, name, architecture_adapter)
return component
18 changes: 0 additions & 18 deletions transformer_lens/model_bridge/generalized_components/moe.py
@@ -35,24 +35,6 @@ def __init__(
"""
super().__init__(original_component, name, architecture_adapter)

@classmethod
def wrap_component(
cls, component: nn.Module, name: str, architecture_adapter: ArchitectureAdapter
) -> nn.Module:
"""Wrap a component with this bridge if it's a MoE layer.

Args:
component: The component to wrap
name: The name of the component
architecture_adapter: The architecture adapter instance

Returns:
The wrapped component if it's a MoE layer, otherwise the original component
"""
if name.endswith(".moe"):
return cls(component, name, architecture_adapter)
return component

def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Forward pass through the MoE bridge.

@@ -72,21 +72,3 @@ def forward(
self.hook_outputs.update({"output": output})

return output

@classmethod
def wrap_component(
cls, component: nn.Module, name: str, architecture_adapter: ArchitectureAdapter
) -> nn.Module:
"""Wrap a component with this bridge if it's an unembedding layer.

Args:
component: The component to wrap
name: The name of the component
architecture_adapter: The architecture adapter instance

Returns:
The wrapped component if it's an unembedding layer, otherwise the original component
"""
if name.endswith(".unembed") or name.endswith(".lm_head"):
return cls(component, name, architecture_adapter)
return component
@@ -9,6 +9,9 @@
from transformer_lens.model_bridge.supported_architectures.bloom import (
BloomArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.deepseek import (
DeepseekArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.gemma1 import (
Gemma1ArchitectureAdapter,
)
@@ -76,6 +79,7 @@
__all__ = [
"BertArchitectureAdapter",
"BloomArchitectureAdapter",
"DeepseekArchitectureAdapter",
"Gemma1ArchitectureAdapter",
"Gemma2ArchitectureAdapter",
"Gemma3ArchitectureAdapter",
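With the adapter imported and added to __all__ here, it can be pulled in alongside the other adapters; a trivial check:

from transformer_lens.model_bridge import supported_architectures
from transformer_lens.model_bridge.supported_architectures import (
    DeepseekArchitectureAdapter,
)

assert "DeepseekArchitectureAdapter" in supported_architectures.__all__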
70 changes: 70 additions & 0 deletions transformer_lens/model_bridge/supported_architectures/deepseek.py
@@ -0,0 +1,70 @@
"""DeepSeek architecture adapter."""

from typing import Any

from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.conversion_utils.conversion_steps import (
WeightConversionSet,
)
from transformer_lens.model_bridge.generalized_components import (
AttentionBridge,
BlockBridge,
EmbeddingBridge,
LayerNormBridge,
MLPBridge,
MoEBridge,
UnembeddingBridge,
)


class DeepseekArchitectureAdapter(ArchitectureAdapter):
"""Architecture adapter for DeepSeek models."""

def __init__(self, cfg: Any) -> None:
"""Initialize the DeepSeek architecture adapter.

Args:
cfg: The configuration object.
"""
super().__init__(cfg)

self.conversion_rules = WeightConversionSet(
{
"embed.W_E": "model.embed_tokens.weight",
"blocks.{i}.ln1.w": "model.layers.{i}.input_layernorm.weight",
# Attention weights
"blocks.{i}.attn.W_Q": "model.layers.{i}.self_attn.q_proj.weight",
"blocks.{i}.attn.W_K": "model.layers.{i}.self_attn.k_proj.weight",
"blocks.{i}.attn.W_V": "model.layers.{i}.self_attn.v_proj.weight",
"blocks.{i}.attn.W_O": "model.layers.{i}.self_attn.o_proj.weight",
"blocks.{i}.ln2.w": "model.layers.{i}.post_attention_layernorm.weight",
# MLP weights for dense layers
"blocks.{i}.mlp.W_gate": "model.layers.{i}.mlp.gate_proj.weight",
"blocks.{i}.mlp.W_in": "model.layers.{i}.mlp.up_proj.weight",
"blocks.{i}.mlp.W_out": "model.layers.{i}.mlp.down_proj.weight",
# MoE weights
"blocks.{i}.moe.gate.w": "model.layers.{i}.mlp.gate.weight",
"blocks.{i}.moe.experts.W_gate.{j}": "model.layers.{i}.mlp.experts.{j}.gate_proj.weight",
"blocks.{i}.moe.experts.W_in.{j}": "model.layers.{i}.mlp.experts.{j}.up_proj.weight",
"blocks.{i}.moe.experts.W_out.{j}": "model.layers.{i}.mlp.experts.{j}.down_proj.weight",
"ln_final.w": "model.norm.weight",
"unembed.W_U": "lm_head.weight",
}
)

self.component_mapping = {
"embed": ("model.embed_tokens", EmbeddingBridge),
"blocks": (
"model.layers",
BlockBridge,
{
"ln1": ("input_layernorm", LayerNormBridge),
"ln2": ("post_attention_layernorm", LayerNormBridge),
"attn": ("self_attn", AttentionBridge),
"mlp": ("mlp", MLPBridge),
"moe": ("mlp", MoEBridge),
},
),
"ln_final": ("model.norm", LayerNormBridge),
"unembed": ("lm_head", UnembeddingBridge),
}
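Finally, a sketch of exercising the adapter end to end through boot(). The checkpoint name is illustrative, and extra loading flags such as trust_remote_code may need to be threaded through model_config or **kwargs depending on the checkpoint; none of that (nor the practicality of loading a model this size locally) is covered by this diff.

import torch

from transformer_lens.boot import boot

# Illustrative checkpoint; any config reporting "DeepseekV3ForCausalLM" should be
# routed to DeepseekArchitectureAdapter by the factory.
bridge = boot(
    "deepseek-ai/DeepSeek-V3",
    dtype=torch.bfloat16,
)

# Weight names then translate per the conversion rules above, for example
# blocks.0.attn.W_Q  <-  model.layers.0.self_attn.q_proj.weight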