(Draft) Add DLA function to utils #466

Open · wants to merge 3 commits into base: dev

108 changes: 107 additions & 1 deletion transformer_lens/utils.py
@@ -26,7 +26,7 @@
from rich import print as rprint
from transformers import AutoTokenizer

from transformer_lens.FactoredMatrix import FactoredMatrix
from transformer_lens import ActivationCache, FactoredMatrix, HookedTransformer

CACHE_DIR = transformers.TRANSFORMERS_CACHE
USE_DEFAULT_VALUE = None
@@ -1202,6 +1202,112 @@ def get_tokens_with_bos_removed(tokenizer, tokens):
return tokens[tokens != -100].view(*bos_removed_shape)


def DLA(
model: HookedTransformer,
prompts: List[str],
answer_tokens: Int[torch.Tensor, "batch answers"],
accumulated: bool = False,
) -> Tuple[Float[torch.Tensor, "component"], List[str]]:
"""Function to calculate the DLA (either accumulated or per layer) for given list of prompts and tokens.

Args:
        model (HookedTransformer): Model to run.
        prompts (List[str]): List of prompts.
        answer_tokens (Int[torch.Tensor, "batch answers"]): Per batch, either a single answer token or a (correct, wrong) pair of tokens.
        accumulated (bool): Whether to return the accumulated DLA or the per-layer DLA.

    Returns:
        Float[torch.Tensor, "component"]: DLA per component.
        List[str]: Labels for each component.
"""
assert len(prompts) == answer_tokens.shape[0]
assert answer_tokens.shape[1] == 1 or answer_tokens.shape[1] == 2
answer_residual_directions: Float[
torch.Tensor, "batch answers d_model"
] = model.tokens_to_residual_directions(answer_tokens)

if (
answer_tokens.numel() == 1
): # special case as tokens_to_residual_directions returns Float[Tensor, "d_model"]
logit_diff_directions: Float[torch.Tensor, "batch d_model"] = torch.unsqueeze(
answer_residual_directions, dim=0
)
elif answer_residual_directions.shape[1] == 1:
logit_diff_directions: Float[
torch.Tensor, "batch d_model"
] = answer_residual_directions[:, 0, :]
else:
(
correct_residual_directions,
incorrect_residual_directions,
) = answer_residual_directions.unbind(dim=1)
logit_diff_directions: Float[torch.Tensor, "batch d_model"] = (
correct_residual_directions - incorrect_residual_directions
)

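    # Helper: scale a stack of residual-stream components by the final LayerNorm at the
    # last position, project onto the logit(-difference) directions, and average over the batch.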
def residual_stack_to_logit_diff(
residual_stack: Float[torch.Tensor, "... batch d_model"],
cache: ActivationCache,
logit_diff_directions: Float[torch.Tensor, "batch d_model"],
) -> Float[torch.Tensor, "..."]:
batch_size = residual_stack.size(-2)
scaled_residual_stack = cache.apply_ln_to_stack(
residual_stack, layer=-1, pos_slice=-1
)
return (
einops.einsum(
scaled_residual_stack,
logit_diff_directions,
"... batch d_model, batch d_model -> ...",
)
/ batch_size
)

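    # Accumulated mode: cache the residual stream at the start and middle of every layer,
    # the final residual, and the final LayerNorm scale, then compute logit-lens style
    # accumulated attributions.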
if accumulated:
n_layers = model.cfg.n_layers
_, cache = model.run_with_cache(
prompts,
return_type=None,
names_filter=lambda x: x == get_act_name("resid_post", n_layers - 1)
or x == get_act_name("ln_final.hook_scale")
or x.endswith("resid_pre")
or x.endswith("resid_mid"),
)

accumulated_residual, labels = cache.accumulated_resid(
layer=-1, pos_slice=-1, incl_mid=True, return_labels=True
)

logit_lens_logit_diffs: Float[
torch.Tensor, "component"
] = residual_stack_to_logit_diff(
accumulated_residual, cache, logit_diff_directions
)

return logit_lens_logit_diffs, labels

else:
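        # Per-layer mode: cache the embeddings, every attention and MLP output, and the
        # final LayerNorm scale, then decompose the residual stream into those components.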
_, cache = model.run_with_cache(
prompts,
return_type=None,
names_filter=lambda x: x == get_act_name("ln_final.hook_scale")
or x.endswith("embed")
or x.endswith("attn_out")
or x.endswith("mlp_out"),
)

per_layer_residual, labels = cache.decompose_resid(
layer=-1, pos_slice=-1, return_labels=True
)
per_layer_logit_diffs: Float[
torch.Tensor, "component"
] = residual_stack_to_logit_diff(
per_layer_residual, cache, logit_diff_directions
)

return per_layer_logit_diffs, labels


try:
import pytest

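For reviewers: a minimal usage sketch of the proposed helper. The model name, prompt, and answer strings below are illustrative assumptions, not part of this diff.

import torch

from transformer_lens import HookedTransformer
from transformer_lens.utils import DLA

# Assumed example setup: GPT-2 small, one prompt, one (correct, wrong) answer pair.
model = HookedTransformer.from_pretrained("gpt2")
prompts = ["The Eiffel Tower is in the city of"]
answer_tokens = torch.tensor(
    [[model.to_single_token(" Paris"), model.to_single_token(" Rome")]],
    device=model.cfg.device,
)

# Per-layer attribution (default) and accumulated, logit-lens style attribution.
per_layer_dla, labels = DLA(model, prompts, answer_tokens)
accumulated_dla, acc_labels = DLA(model, prompts, answer_tokens, accumulated=True)

for label, value in zip(labels, per_layer_dla.tolist()):
    print(f"{label}: {value:.3f}")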