diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index c3fe2963e..4d8f7cbca 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -24,12 +24,9 @@ from rich import print as rprint from transformers import AutoTokenizer -from transformer_lens import FactoredMatrix - CACHE_DIR = transformers.TRANSFORMERS_CACHE USE_DEFAULT_VALUE = None - def select_compatible_kwargs( kwargs_dict: Dict[str, Any], callable: Callable ) -> Dict[str, Any]: @@ -97,8 +94,14 @@ def get_corner(tensor, n=3): # Prints the top left corner of the tensor if isinstance(tensor, torch.Tensor): return tensor[tuple(slice(n) for _ in range(tensor.ndim))] - elif isinstance(tensor, FactoredMatrix): - return tensor[tuple(slice(n) for _ in range(tensor.ndim))].AB + else: + # pylint: disable=wrong-import-position + # isort: off + from transformer_lens import FactoredMatrix # Lazy import to stop circular dependencies + # isort: on + # pylint: enable=wrong-import-position + if isinstance(tensor, FactoredMatrix): + return tensor[tuple(slice(n) for _ in range(tensor.ndim))].AB def to_numpy(tensor):