Add support for model_name-less models. #603

Draft · wants to merge 2 commits into main
21 changes: 13 additions & 8 deletions transformer_lens/HookedTransformer.py
@@ -1038,7 +1038,7 @@ def move_model_modules_to_device(self):
     @classmethod
     def from_pretrained(
         cls,
-        model_name: str,
+        model_name: Union[str, None],
         fold_ln: bool = True,
         center_writing_weights: bool = True,
         center_unembed: bool = True,
@@ -1076,10 +1076,10 @@ def from_pretrained(
         Loaded pretrained model tiny-stories-1M into HookedTransformer

         Args:
-            model_name: The model name - must be an element of
+            model_name: The model name - if a string, must be an element of
                 :const:`transformer_lens.loading_from_pretrained.OFFICIAL_MODEL_NAMES` or an alias
                 of one. The full list of available models can be found in the docs under :doc:`model
-                properties</generated/model_properties_table>`.
+                properties</generated/model_properties_table>`. If None, you must provide hf_model.
             fold_ln: Whether to fold in the LayerNorm weights to the
                 subsequent linear layer. This does not change the computation.

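
As an illustration of the calling pattern this docstring change describes, here is a minimal sketch (the repo id is hypothetical, and it assumes the rest of this draft's loading path accepts model_name=None):

from transformers import AutoModelForCausalLM
from transformer_lens import HookedTransformer

# A checkpoint that has no TransformerLens alias (hypothetical repo id).
hf_model = AutoModelForCausalLM.from_pretrained("my-org/my-finetuned-llama")

# With this change, model_name may be None as long as hf_model is supplied.
model = HookedTransformer.from_pretrained(model_name=None, hf_model=hf_model)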

@@ -1195,6 +1195,8 @@ def from_pretrained(
             or from_pretrained_kwargs.get("load_in_4bit", False)
         ), "Quantization not supported"

+        assert hf_model is not None or model_name is not None, "Must specify model_name or hf_model."
+
         if hf_model is not None:
             hf_cfg = hf_model.config.to_dict()
             qc = hf_cfg.get("quantization_config", {})
@@ -1206,7 +1208,7 @@
                 load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1"))
             ), "Quantization is only supported for torch versions >= 2.1.1"
             assert not (
-                load_in_4bit and ("llama" not in model_name.lower())
+                load_in_4bit and hf_cfg.get("model_type") != "llama"
             ), "Quantization is only supported for Llama models"
             if load_in_4bit:
                 assert (
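
For reference, hf_cfg here is the plain dict returned by hf_model.config.to_dict(), so architecture and quantization details are read from keys rather than from config classes; a short sketch with illustrative values:

from transformers import AutoConfig

hf_cfg = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf").to_dict()  # illustrative checkpoint
hf_cfg["model_type"]  # "llama" for Llama-family checkpoints
hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)  # False unless loaded with 4-bit quantization
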
@@ -1229,9 +1231,12 @@
         ) and device in ["cpu", None]:
             logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.")

-        # Get the model name used in HuggingFace, rather than the alias.
-        official_model_name = loading.get_official_model_name(model_name)
-
+        if model_name is not None:
+            # Get the model name used in HuggingFace, rather than the alias.
+            official_model_name = loading.get_official_model_name(model_name)
+        else:
+            official_model_name = None
+
         # Load the config into an HookedTransformerConfig object. If loading from a
         # checkpoint, the config object will contain the information about the
         # checkpoint
@@ -1294,7 +1299,7 @@
         if move_to_device:
             model.move_model_modules_to_device()

-        print(f"Loaded pretrained model {model_name} into HookedTransformer")
+        print(f"Loaded pretrained model with {model_name=} into HookedTransformer")

         return model

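
Side note on the new log line: {model_name=} uses the self-documenting f-string specifier (Python 3.8+), so the message prints both the variable name and its value, for example:

model_name = None
print(f"Loaded pretrained model with {model_name=} into HookedTransformer")
# Loaded pretrained model with model_name=None into HookedTransformer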

9 changes: 3 additions & 6 deletions transformer_lens/loading_from_pretrained.py
@@ -1375,7 +1375,7 @@ def get_checkpoint_labels(model_name: str, **kwargs):

 # %% Loading state dicts
 def get_pretrained_state_dict(
-    official_model_name: str,
+    official_model_name: str | None,
     cfg: HookedTransformerConfig,
     hf_model=None,
     dtype: torch.dtype = torch.float32,
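
One note on the annotation: the str | None union syntax is only accepted at runtime on Python 3.10+ (or with from __future__ import annotations); if older interpreters must stay supported, the equivalent typing spelling, in line with the Union form used in the HookedTransformer.py hunk above, is:

from typing import Optional

official_model_name: Optional[str] = None  # same meaning as "str | None", valid on older Pythons
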
@@ -1395,18 +1395,15 @@ def get_pretrained_state_dict(
     if "torch_dtype" in kwargs:
         dtype = kwargs["torch_dtype"]
         del kwargs["torch_dtype"]
-    official_model_name = get_official_model_name(official_model_name)
-    if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
+    if official_model_name is not None and official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
         "trust_remote_code", False
     ):
         logging.warning(
             f"Loading model {official_model_name} state dict requires setting trust_remote_code=True"
         )
         kwargs["trust_remote_code"] = True
     if (
-        official_model_name.startswith("NeelNanda")
-        or official_model_name.startswith("ArthurConmy")
-        or official_model_name.startswith("Baidicoot")
+        official_model_name is not None and official_model_name.startswith(("NeelNanda", "ArthurConmy", "Baidicoot"))
     ):
         api = HfApi()
         repo_files = api.list_repo_files(
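
The consolidated prefix check above relies on str.startswith accepting a tuple of prefixes, which matches if any one of them matches, for example:

repo_id = "NeelNanda/gpt2-small-seed1"  # illustrative repo id
repo_id.startswith(("NeelNanda", "ArthurConmy", "Baidicoot"))  # True
"gpt2".startswith(("NeelNanda", "ArthurConmy", "Baidicoot"))   # False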