From a3dbd5f8801b2d857a092904a656e376aeb1f48c Mon Sep 17 00:00:00 2001 From: Arthur Conmy <35957051+ArthurConmy@users.noreply.github.com> Date: Fri, 22 Mar 2024 18:04:53 +0000 Subject: [PATCH 1/9] Remove FactoredMatrix.py<->utils.py circular dependency --- transformer_lens/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index c3fe2963e..c74aa76fd 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -24,8 +24,6 @@ from rich import print as rprint from transformers import AutoTokenizer -from transformer_lens import FactoredMatrix - CACHE_DIR = transformers.TRANSFORMERS_CACHE USE_DEFAULT_VALUE = None @@ -97,7 +95,7 @@ def get_corner(tensor, n=3): # Prints the top left corner of the tensor if isinstance(tensor, torch.Tensor): return tensor[tuple(slice(n) for _ in range(tensor.ndim))] - elif isinstance(tensor, FactoredMatrix): + elif isinstance(tensor, "FactoredMatrix"): return tensor[tuple(slice(n) for _ in range(tensor.ndim))].AB From 46a1e2de8c89b1e2d03fb2359d254e60b1e187c2 Mon Sep 17 00:00:00 2001 From: Arthur Conmy <35957051+ArthurConmy@users.noreply.github.com> Date: Fri, 22 Mar 2024 18:13:50 +0000 Subject: [PATCH 2/9] Update utils.py --- transformer_lens/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index c74aa76fd..e7b3cf0ce 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -95,8 +95,10 @@ def get_corner(tensor, n=3): # Prints the top left corner of the tensor if isinstance(tensor, torch.Tensor): return tensor[tuple(slice(n) for _ in range(tensor.ndim))] - elif isinstance(tensor, "FactoredMatrix"): - return tensor[tuple(slice(n) for _ in range(tensor.ndim))].AB + else + from transformer_lens import FactoredMatrix # Lazy import to stop circular dependencies + isinstance(tensor, FactoredMatrix): + return tensor[tuple(slice(n) for _ in range(tensor.ndim))].AB def to_numpy(tensor): From c758ebe8f3c534db7adca878a68ff1a476f4e025 Mon Sep 17 00:00:00 2001 From: Arthur Conmy <35957051+ArthurConmy@users.noreply.github.com> Date: Fri, 22 Mar 2024 18:14:18 +0000 Subject: [PATCH 3/9] Update utils.py --- transformer_lens/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index e7b3cf0ce..e3aa7c64a 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -95,7 +95,7 @@ def get_corner(tensor, n=3): # Prints the top left corner of the tensor if isinstance(tensor, torch.Tensor): return tensor[tuple(slice(n) for _ in range(tensor.ndim))] - else + else: from transformer_lens import FactoredMatrix # Lazy import to stop circular dependencies isinstance(tensor, FactoredMatrix): return tensor[tuple(slice(n) for _ in range(tensor.ndim))].AB From 3a0a4d2aa134ed8d185c01be36e09442841f6195 Mon Sep 17 00:00:00 2001 From: Arthur Conmy <35957051+ArthurConmy@users.noreply.github.com> Date: Fri, 22 Mar 2024 18:15:00 +0000 Subject: [PATCH 4/9] Update utils.py --- transformer_lens/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index e3aa7c64a..deb6ae71b 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -97,7 +97,7 @@ def get_corner(tensor, n=3): return tensor[tuple(slice(n) for _ in range(tensor.ndim))] else: from transformer_lens import FactoredMatrix # Lazy import to stop circular dependencies - isinstance(tensor, FactoredMatrix): + if isinstance(tensor, FactoredMatrix): return tensor[tuple(slice(n) for _ in range(tensor.ndim))].AB From 28775d1402fd190a44457b1c87d068f3e2cb4677 Mon Sep 17 00:00:00 2001 From: Arthur Conmy <35957051+ArthurConmy@users.noreply.github.com> Date: Fri, 22 Mar 2024 18:24:46 +0000 Subject: [PATCH 5/9] Update utils.py --- transformer_lens/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index deb6ae71b..92c74bce5 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -95,8 +95,10 @@ def get_corner(tensor, n=3): # Prints the top left corner of the tensor if isinstance(tensor, torch.Tensor): return tensor[tuple(slice(n) for _ in range(tensor.ndim))] - else: + else: + # pylint: disable=wrong-import-position from transformer_lens import FactoredMatrix # Lazy import to stop circular dependencies + # pylint: enable=wrong-import-position if isinstance(tensor, FactoredMatrix): return tensor[tuple(slice(n) for _ in range(tensor.ndim))].AB From 9378dde09a438678be4e324031925dff227221bc Mon Sep 17 00:00:00 2001 From: Arthur Conmy <35957051+ArthurConmy@users.noreply.github.com> Date: Fri, 22 Mar 2024 19:26:28 +0000 Subject: [PATCH 6/9] Update utils.py --- transformer_lens/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index 92c74bce5..94594e301 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -9,7 +9,9 @@ import re import shutil from copy import deepcopy -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast +from typing import ( + Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast +) import einops import numpy as np @@ -22,12 +24,12 @@ from huggingface_hub import hf_hub_download from jaxtyping import Float, Int from rich import print as rprint +from transformer_lens import FactoredMatrix from transformers import AutoTokenizer CACHE_DIR = transformers.TRANSFORMERS_CACHE USE_DEFAULT_VALUE = None - def select_compatible_kwargs( kwargs_dict: Dict[str, Any], callable: Callable ) -> Dict[str, Any]: From 3fad04845726513d61e4996e35f46079fc2de499 Mon Sep 17 00:00:00 2001 From: Arthur Conmy <35957051+ArthurConmy@users.noreply.github.com> Date: Fri, 22 Mar 2024 19:27:06 +0000 Subject: [PATCH 7/9] Update utils.py --- transformer_lens/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index 94594e301..a856c8d43 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -24,7 +24,6 @@ from huggingface_hub import hf_hub_download from jaxtyping import Float, Int from rich import print as rprint -from transformer_lens import FactoredMatrix from transformers import AutoTokenizer CACHE_DIR = transformers.TRANSFORMERS_CACHE From 2b1629c882d5cd34821391dc23063818cdc97d56 Mon Sep 17 00:00:00 2001 From: Arthur Conmy <35957051+ArthurConmy@users.noreply.github.com> Date: Fri, 22 Mar 2024 20:06:27 +0000 Subject: [PATCH 8/9] Update utils.py --- transformer_lens/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index a856c8d43..9ad1434c4 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -98,7 +98,9 @@ def get_corner(tensor, n=3): return tensor[tuple(slice(n) for _ in range(tensor.ndim))] else: # pylint: disable=wrong-import-position + # isort: off from transformer_lens import FactoredMatrix # Lazy import to stop circular dependencies + # isort: on # pylint: enable=wrong-import-position if isinstance(tensor, FactoredMatrix): return tensor[tuple(slice(n) for _ in range(tensor.ndim))].AB From 3fc72bfaf3fcfee41e06fdb71c90b327ecb53e22 Mon Sep 17 00:00:00 2001 From: Arthur Conmy <35957051+ArthurConmy@users.noreply.github.com> Date: Fri, 22 Mar 2024 20:17:40 +0000 Subject: [PATCH 9/9] Update utils.py --- transformer_lens/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index 9ad1434c4..4d8f7cbca 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -9,9 +9,7 @@ import re import shutil from copy import deepcopy -from typing import ( - Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast -) +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast import einops import numpy as np