diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index baf32ad05..0a0ae69b9 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -1038,7 +1038,7 @@ def move_model_modules_to_device(self):
     @classmethod
     def from_pretrained(
         cls,
-        model_name: str,
+        model_name: Union[str, None],
         fold_ln: bool = True,
         center_writing_weights: bool = True,
         center_unembed: bool = True,
@@ -1076,10 +1076,10 @@ def from_pretrained(
             Loaded pretrained model tiny-stories-1M into HookedTransformer
 
         Args:
-            model_name: The model name - must be an element of
+            model_name: The model name - if a string, must be an element of
                 :const:`transformer_lens.loading_from_pretrained.OFFICIAL_MODEL_NAMES`
                 or an alias of one. The full list of available models can be found
                 in the docs under :doc:`model
-                properties`.
+                properties`. If None, you must provide hf_model.
             fold_ln: Whether to fold in the LayerNorm weights to the subsequent linear
                 layer. This does not change the computation.
@@ -1195,6 +1195,8 @@ def from_pretrained(
             or from_pretrained_kwargs.get("load_in_4bit", False)
         ), "Quantization not supported"
 
+        assert hf_model is not None or model_name is not None, "Must specify model_name or hf_model."
+
         if hf_model is not None:
             hf_cfg = hf_model.config.to_dict()
             qc = hf_cfg.get("quantization_config", {})
@@ -1206,7 +1208,7 @@ def from_pretrained(
                 load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1"))
             ), "Quantization is only supported for torch versions >= 2.1.1"
             assert not (
-                load_in_4bit and ("llama" not in model_name.lower())
+                load_in_4bit and not isinstance(hf_model.config, transformers.models.llama.configuration_llama.LlamaConfig)
             ), "Quantization is only supported for Llama models"
             if load_in_4bit:
                 assert (
@@ -1229,9 +1231,12 @@ def from_pretrained(
         ) and device in ["cpu", None]:
             logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.")
 
-        # Get the model name used in HuggingFace, rather than the alias.
-        official_model_name = loading.get_official_model_name(model_name)
-
+        if model_name is not None:
+            # Get the model name used in HuggingFace, rather than the alias.
+            official_model_name = loading.get_official_model_name(model_name)
+        else:
+            official_model_name = None
+
         # Load the config into an HookedTransformerConfig object. If loading from a
         # checkpoint, the config object will contain the information about the
         # checkpoint
@@ -1294,7 +1299,7 @@ def from_pretrained(
         if move_to_device:
             model.move_model_modules_to_device()
 
-        print(f"Loaded pretrained model {model_name} into HookedTransformer")
+        print(f"Loaded pretrained model with {model_name=} into HookedTransformer")
 
         return model
 
diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index f8d9d2c19..f093e1574 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -1375,7 +1375,7 @@ def get_checkpoint_labels(model_name: str, **kwargs):
 
 # %% Loading state dicts
 def get_pretrained_state_dict(
-    official_model_name: str,
+    official_model_name: Optional[str],
     cfg: HookedTransformerConfig,
     hf_model=None,
     dtype: torch.dtype = torch.float32,
@@ -1395,8 +1395,7 @@ def get_pretrained_state_dict(
     if "torch_dtype" in kwargs:
         dtype = kwargs["torch_dtype"]
         del kwargs["torch_dtype"]
-    official_model_name = get_official_model_name(official_model_name)
-    if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
+    if official_model_name is not None and official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
        "trust_remote_code", False
     ):
         logging.warning(
@@ -1404,9 +1403,7 @@ def get_pretrained_state_dict(
         )
         kwargs["trust_remote_code"] = True
     if (
-        official_model_name.startswith("NeelNanda")
-        or official_model_name.startswith("ArthurConmy")
-        or official_model_name.startswith("Baidicoot")
+        official_model_name is not None and official_model_name.startswith(("NeelNanda", "ArthurConmy", "Baidicoot"))
     ):
         api = HfApi()
         repo_files = api.list_repo_files(
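
For context, a minimal sketch of the calling pattern this diff is meant to allow: passing a pre-loaded HuggingFace model with `model_name=None`. This is illustrative only, not part of the patch; the model id `"gpt2"` is just an example, and whether the call runs end-to-end also depends on parts of the PR not shown here (the config-loading path must accept `official_model_name=None`).

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from transformer_lens import HookedTransformer

# Load the weights once through HuggingFace, then hand the in-memory model
# to TransformerLens instead of naming an official model.
hf_model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

model = HookedTransformer.from_pretrained(
    model_name=None,    # now typed Union[str, None] by this diff
    hf_model=hf_model,  # required when model_name is None
    tokenizer=tokenizer,
)

# Calling from_pretrained() with neither model_name nor hf_model now trips
# the new assertion: "Must specify model_name or hf_model."
```

When `model_name` is given, behaviour is unchanged: the alias is still resolved via `loading.get_official_model_name` before the state dict is fetched.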