Skip to content

Bridged Robust Model Structure #960

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 50 commits into from
Jul 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
d98f3a5
added generalized hooks
bryce13950 Jul 8, 2025
fed9088
created linear generalized component
bryce13950 Jul 8, 2025
dde7e9c
updated component mapping to use configured instances
bryce13950 Jul 8, 2025
60de6d4
updated module bridge to properly init modules where it should be
bryce13950 Jul 8, 2025
af25fc1
simplified types
bryce13950 Jul 8, 2025
ab39667
updated remaining architectures to use new structure
bryce13950 Jul 8, 2025
5480090
configured cache properly
bryce13950 Jul 8, 2025
b48e99c
ran format
bryce13950 Jul 9, 2025
5be0435
removed debug prints
bryce13950 Jul 9, 2025
ff82ddb
fixed some tests
bryce13950 Jul 10, 2025
8137eaa
Merge branch 'dev-3.x' into bridge-robust-model-structure
bryce13950 Jul 10, 2025
d4f3b24
ran format
bryce13950 Jul 10, 2025
6f1abea
restored function
bryce13950 Jul 10, 2025
8cbd010
fixed some tests
bryce13950 Jul 11, 2025
aa40078
fixed remaining unit tests
bryce13950 Jul 11, 2025
d0f2ab4
removed extra test
bryce13950 Jul 11, 2025
914189e
fixed format
bryce13950 Jul 11, 2025
7d2b645
fixed docstring
bryce13950 Jul 14, 2025
c652b28
moved setup stuff to its own file
bryce13950 Jul 14, 2025
1332f36
organized submodules properly
bryce13950 Jul 14, 2025
69c7fef
simplified list creation
bryce13950 Jul 14, 2025
79bf22c
fixed up submodules properly
bryce13950 Jul 14, 2025
653d1ec
generalized setup more
bryce13950 Jul 14, 2025
6e6657e
simplified component mapping
bryce13950 Jul 14, 2025
e03fe46
updated functions for recent changes
bryce13950 Jul 14, 2025
09cd0ed
ran format
bryce13950 Jul 14, 2025
a35fde8
fleshed out test
bryce13950 Jul 14, 2025
dbf6cef
restored test
bryce13950 Jul 14, 2025
6ec7f10
ran format
bryce13950 Jul 14, 2025
ca0fb51
fixed mypy issues
bryce13950 Jul 14, 2025
e72549c
ran format
bryce13950 Jul 14, 2025
2b502d2
passed tests
bryce13950 Jul 14, 2025
c08fc33
fixed test
bryce13950 Jul 14, 2025
43a386b
moved d_mlp config
bryce13950 Jul 14, 2025
f8bf18b
finished refactor
bryce13950 Jul 14, 2025
8d2ea4a
ran format
bryce13950 Jul 14, 2025
9031d6e
ran format
bryce13950 Jul 14, 2025
b2a0150
removed comment
bryce13950 Jul 14, 2025
5e2dfba
removed extra checks
bryce13950 Jul 14, 2025
730ad72
improved mypy coverage
bryce13950 Jul 14, 2025
cc7c77a
updated the way hook points are created
bryce13950 Jul 15, 2025
a92e128
removed old hook implementation
bryce13950 Jul 15, 2025
05a409b
fixed infinite loop issues
bryce13950 Jul 15, 2025
e02af81
removed extra hook
bryce13950 Jul 15, 2025
9546e79
restored hook
bryce13950 Jul 15, 2025
7e7e08b
fixed hook management in embed
bryce13950 Jul 15, 2025
8050877
removed extra comments
bryce13950 Jul 15, 2025
a323a4e
moved component to main dictionary
bryce13950 Jul 15, 2025
e1358c7
added input capability
bryce13950 Jul 15, 2025
38d274b
ran format
bryce13950 Jul 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/integration/model_bridge/test_bridge_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,11 @@ def test_cache():
assert isinstance(cache, ActivationCache), "Cache should be an ActivationCache object"
assert len(cache) > 0, "Cache should contain activations"

# Verify cache contains some expected keys (using actual HuggingFace model structure)
# Verify cache contains some expected keys (using TransformerLens naming convention)
# The exact keys depend on the model architecture, but we should have some basic ones
cache_keys = list(cache.keys())
assert any("wte" in key for key in cache_keys), "Cache should contain word token embeddings"
assert any("ln_f" in key for key in cache_keys), "Cache should contain final layer norm"
assert any("embed" in key for key in cache_keys), "Cache should contain word token embeddings"
assert any("ln_final" in key for key in cache_keys), "Cache should contain final layer norm"
assert any("lm_head" in key for key in cache_keys), "Cache should contain language model head"

# Verify that cached tensors are actually tensors
Expand Down
59 changes: 40 additions & 19 deletions tests/mocks/architecture_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,55 @@ class MockArchitectureAdapter(ArchitectureAdapter):
"""Mock architecture adapter for testing."""

def __init__(self, cfg=None):
if cfg is None:
# Create a minimal config for testing
cfg = type(
"MockConfig",
(),
{"d_mlp": 512, "intermediate_size": 512, "default_prepend_bos": True},
)()
super().__init__(cfg)
# Use actual bridge instances instead of tuples
self.component_mapping = {
"embed": ("embed", EmbeddingBridge),
"unembed": ("unembed", EmbeddingBridge),
"ln_final": ("ln_final", LayerNormBridge),
"blocks": (
"blocks",
BlockBridge,
{
"ln1": ("ln1", LayerNormBridge),
"ln2": ("ln2", LayerNormBridge),
"attn": ("attn", AttentionBridge),
"mlp": ("mlp", MLPBridge),
"embed": EmbeddingBridge(name="embed"),
"unembed": EmbeddingBridge(name="unembed"),
"ln_final": LayerNormBridge(name="ln_final"),
"blocks": BlockBridge(
name="blocks",
submodules={
"ln1": LayerNormBridge(name="ln1"),
"ln2": LayerNormBridge(name="ln2"),
"attn": AttentionBridge(name="attn"),
"mlp": MLPBridge(name="mlp"),
},
),
"outer_blocks": (
"outer_blocks",
BlockBridge,
{
"inner_blocks": (
"inner_blocks",
BlockBridge,
{"ln": ("ln", LayerNormBridge)},
"outer_blocks": BlockBridge(
name="outer_blocks",
submodules={
"inner_blocks": BlockBridge(
name="inner_blocks",
submodules={"ln": LayerNormBridge(name="ln")},
)
},
),
}

# Set up the submodules properly by registering them as PyTorch modules
self._setup_mock_submodules()

def _setup_mock_submodules(self):
"""Set up submodules for testing by registering them as PyTorch modules."""
for component_name, component in self.component_mapping.items():
self._register_submodules(component)

def _register_submodules(self, component):
"""Recursively register submodules for a component."""
if component.submodules:
for submodule_name, submodule in component.submodules.items():
component.add_module(submodule_name, submodule)
# Recursively register nested submodules
self._register_submodules(submodule)


@pytest.fixture
def mock_adapter() -> MockArchitectureAdapter:
Expand Down
1 change: 1 addition & 0 deletions tests/mocks/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,5 @@ def __init__(self):
layer.mlp.gate_proj = nn.Linear(512, 2048)
layer.mlp.down_proj = nn.Linear(2048, 512)
self.model.norm = nn.LayerNorm(512)
self.lm_head = nn.Linear(512, 1000) # Add missing lm_head
self.embed_tokens = self.model.embed_tokens # For shared embedding/unembedding
172 changes: 90 additions & 82 deletions tests/unit/model_bridge/test_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,61 @@ class TestTransformerBridge:
@pytest.fixture(autouse=True)
def setup_method(self, mock_adapter, mock_model_adapter):
"""Set up test fixtures."""

# Mock the get_component method to return expected components for formatting tests
def mock_get_component(model, path):
    """Return a mock bridge component for ``path``, wired to ``model``.

    Matching is by substring of ``path``. Top-level components (unembed,
    embed, ln_final) are checked first, then block-level components, which
    are always resolved against ``model.blocks[0]``. Paths that match
    nothing yield a generic, unwired ``EmbeddingBridge`` named "unknown".
    """
    # "unembed" contains "embed" as a substring, so it has to be tested
    # before "embed"; otherwise every unembed path would be answered with
    # the embedding bridge and this branch could never be reached.
    if "unembed" in path:
        comp = EmbeddingBridge(name="unembed")
        comp.set_original_component(model.unembed)
        return comp
    if "embed" in path:
        comp = EmbeddingBridge(name="embed")
        comp.set_original_component(model.embed)
        return comp
    if "ln_final" in path:
        comp = LayerNormBridge(name="ln_final")
        comp.set_original_component(model.ln_final)
        return comp
    if "blocks" in path:
        first_block = model.blocks[0]
        # Checked in this order so the first matching submodule name wins,
        # mirroring the original if/elif chain.
        block_parts = (
            ("attn", AttentionBridge),
            ("mlp", MLPBridge),
            ("ln1", LayerNormBridge),
            ("ln2", LayerNormBridge),
        )
        for part_name, bridge_cls in block_parts:
            if part_name in path:
                comp = bridge_cls(name=part_name)
                comp.set_original_component(getattr(first_block, part_name))
                return comp
        # A bare "blocks" path refers to the block itself.
        comp = BlockBridge(name="blocks")
        comp.set_original_component(first_block)
        return comp
    # Return a generic component for unknown paths
    return EmbeddingBridge(name="unknown")

mock_adapter.get_component = mock_get_component
self.bridge = TransformerBridge(mock_model_adapter, mock_adapter, MagicMock())
mock_adapter.user_cfg = MagicMock()
self.bridge.cfg = mock_adapter.user_cfg
mock_adapter.cfg = MagicMock()
self.bridge.cfg = mock_adapter.cfg

def test_format_remote_import_tuple(self):
"""Test formatting of RemoteImport tuples (like embed, ln_final, unembed)."""
# This is the case that was causing the original bug
"""Test formatting of bridge instances (like embed, ln_final, unembed)."""
# Updated to use actual bridge instances instead of tuples
mapping = {
"embed": ("embed", EmbeddingBridge),
"ln_final": ("ln_final", LayerNormBridge),
"unembed": ("unembed", EmbeddingBridge),
"embed": EmbeddingBridge(name="embed"),
"ln_final": LayerNormBridge(name="ln_final"),
"unembed": EmbeddingBridge(name="unembed"),
}
self.bridge.bridge.component_mapping = mapping
self.bridge.adapter.component_mapping = mapping

result = self.bridge._format_component_mapping(mapping, indent=1)

Expand All @@ -50,20 +92,19 @@ def test_format_remote_import_tuple(self):
assert line.startswith(" ") # 1 level of indentation

def test_format_block_mapping_tuple(self):
"""Test formatting of BlockMapping tuples (like blocks)."""
"""Test formatting of BlockBridge instances (like blocks)."""
mapping = {
"blocks": (
"blocks",
BlockBridge,
{
"ln1": ("ln1", LayerNormBridge),
"ln2": ("ln2", LayerNormBridge),
"attn": ("attn", AttentionBridge),
"mlp": ("mlp", MLPBridge),
"blocks": BlockBridge(
name="blocks",
submodules={
"ln1": LayerNormBridge(name="ln1"),
"ln2": LayerNormBridge(name="ln2"),
"attn": AttentionBridge(name="attn"),
"mlp": MLPBridge(name="mlp"),
},
)
}
self.bridge.bridge.component_mapping = mapping
self.bridge.adapter.component_mapping = mapping

result = self.bridge._format_component_mapping(mapping, indent=1)

Expand All @@ -76,20 +117,19 @@ def test_format_block_mapping_tuple(self):
assert " mlp:" in result[4]

def test_format_mixed_mapping(self):
"""Test formatting of a mapping with both RemoteImport and BlockMapping tuples."""
"""Test formatting of a mapping with both simple and block bridge instances."""
mapping = {
"embed": ("embed", EmbeddingBridge),
"blocks": (
"blocks",
BlockBridge,
{
"ln1": ("ln1", LayerNormBridge),
"attn": ("attn", AttentionBridge),
"embed": EmbeddingBridge(name="embed"),
"blocks": BlockBridge(
name="blocks",
submodules={
"ln1": LayerNormBridge(name="ln1"),
"attn": AttentionBridge(name="attn"),
},
),
"ln_final": ("ln_final", LayerNormBridge),
"ln_final": LayerNormBridge(name="ln_final"),
}
self.bridge.bridge.component_mapping = mapping
self.bridge.adapter.component_mapping = mapping

result = self.bridge._format_component_mapping(mapping, indent=0)

Expand All @@ -104,15 +144,14 @@ def test_format_mixed_mapping(self):
def test_format_with_prepend_path(self):
"""Test formatting with prepend path parameter."""
mapping = {
"ln1": ("ln1", LayerNormBridge),
"attn": ("attn", AttentionBridge),
"ln1": LayerNormBridge(name="ln1"),
"attn": AttentionBridge(name="attn"),
}
# To test prepending, we need a parent structure in the component mapping
self.bridge.bridge.component_mapping = {
"blocks": (
"blocks",
BlockBridge,
mapping,
self.bridge.adapter.component_mapping = {
"blocks": BlockBridge(
name="blocks",
submodules=mapping,
)
}

Expand All @@ -126,18 +165,18 @@ def test_format_with_prepend_path(self):
def test_format_empty_mapping(self):
"""Test formatting of an empty mapping."""
mapping = {}
self.bridge.bridge.component_mapping = mapping
self.bridge.adapter.component_mapping = mapping

result = self.bridge._format_component_mapping(mapping, indent=1)

assert result == []

def test_format_non_tuple_values(self):
"""Test formatting when mapping contains non-tuple values."""
def test_format_non_bridge_values(self):
"""Test formatting when mapping contains non-bridge values."""
mapping = {
"some_component": "simple_string_value",
}
self.bridge.bridge.component_mapping = mapping
self.bridge.adapter.component_mapping = mapping

result = self.bridge._format_component_mapping(mapping, indent=1)

Expand All @@ -147,21 +186,19 @@ def test_format_non_tuple_values(self):
def test_format_nested_block_mappings(self):
"""Test formatting of nested block mappings."""
mapping = {
"outer_blocks": (
"outer_blocks",
BlockBridge,
{
"inner_blocks": (
"inner_blocks",
BlockBridge,
{
"ln": ("ln", LayerNormBridge),
"outer_blocks": BlockBridge(
name="outer_blocks",
submodules={
"inner_blocks": BlockBridge(
name="inner_blocks",
submodules={
"ln": LayerNormBridge(name="ln"),
},
)
},
)
}
self.bridge.bridge.component_mapping = mapping
self.bridge.adapter.component_mapping = mapping

result = self.bridge._format_component_mapping(mapping, indent=0)

Expand All @@ -174,9 +211,9 @@ def test_format_nested_block_mappings(self):
def test_format_component_mapping_error_handling(self):
"""Test that the method handles errors gracefully when components can't be found."""
mapping = {
"nonexistent_component": ("path.to.nowhere", EmbeddingBridge),
"nonexistent_component": EmbeddingBridge(name="path.to.nowhere"),
}
self.bridge.bridge.component_mapping = mapping
self.bridge.adapter.component_mapping = mapping

# This should not raise an exception, but should handle the error in _format_single_component
result = self.bridge._format_component_mapping(mapping, indent=1)
Expand All @@ -188,9 +225,9 @@ def test_format_component_mapping_error_handling(self):
def test_indentation_levels(self):
"""Test that indentation is applied correctly at different levels."""
mapping = {
"level0": ("embed", EmbeddingBridge),
"level0": EmbeddingBridge(name="embed"),
}
self.bridge.bridge.component_mapping = mapping
self.bridge.adapter.component_mapping = mapping

# Test different indentation levels
result_0 = self.bridge._format_component_mapping(mapping, indent=0)
Expand All @@ -201,35 +238,6 @@ def test_indentation_levels(self):
assert result_1[0].startswith(" ") # 1 level (2 spaces)
assert result_2[0].startswith(" ") # 2 levels (4 spaces)

def test_regression_original_bug(self):
"""Regression test for the original bug where EmbeddingBridge was treated as a dict."""
# This is the exact scenario that was causing the AttributeError
mapping = {
"embed": ("embed", EmbeddingBridge),
"blocks": (
"blocks",
BlockBridge,
{
"attn": ("attn", AttentionBridge),
},
),
"unembed": ("unembed", EmbeddingBridge),
}
self.bridge.bridge.component_mapping = mapping

# This should not raise AttributeError: type object 'EmbeddingBridge' has no attribute 'items'
try:
result = self.bridge._format_component_mapping(mapping, indent=1)
# If we get here, the bug is fixed
assert len(result) == 4 # embed + blocks + attn + unembed
except AttributeError as e:
if "has no attribute 'items'" in str(e):
pytest.fail(
"Original bug still present: RemoteImport tuples being treated as BlockMapping"
)
else:
raise # Re-raise if it's a different AttributeError


if __name__ == "__main__":
pytest.main([__file__, "-v"])
Loading