[Bug]: Avg Pooling in the Entity Linker #3122

Closed
aynetdia opened this issue Feb 23, 2023 · 0 comments · Fixed by #3123

Collaborator

aynetdia commented Feb 23, 2023

Describe the bug

A RuntimeError is raised during prediction when the Entity Linker uses "average" as its pooling operation.

To Reproduce

from flair.data import Corpus
from flair.datasets import NEL_ENGLISH_TWEEKI
from flair.embeddings import TransformerWordEmbeddings
from flair.models import EntityLinker

corpus: Corpus = NEL_ENGLISH_TWEEKI(sample_missing_splits=False)

embeddings = TransformerWordEmbeddings(
    model="distilbert-base-uncased",
    fine_tune=True,
)

entity_linker = EntityLinker(
    embeddings=embeddings,
    label_dictionary=corpus.make_label_dictionary(label_type="nel"),
    label_type="nel",
    pooling_operation="average",
)

entity_linker.predict(corpus.train[0])

Expected behavior

The Entity Linker should perform average pooling without errors, as it does with the other pooling options.

Logs and Stack traces

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In [1], line 20
      8 embeddings = TransformerWordEmbeddings(
      9     model="distilbert-base-uncased",
     10     fine_tune=True,
     11 )
     13 entity_linker = EntityLinker(
     14     embeddings=embeddings,
     15     label_dictionary=corpus.make_label_dictionary(label_type="nel"),
     16     label_type="nel",
     17     pooling_operation="average",
     18 )
---> 20 entity_linker.predict(corpus.train[0])

File ~/projects/flair_forked/flair/nn/model.py:826, in DefaultClassifier.predict(self, sentences, mini_batch_size, return_probabilities_for_all_classes, verbose, label_name, return_loss, embedding_storage_mode)
    824 # pass data points through network and decode
    825 data_point_tensor = self._encode_data_points(batch, data_points)
--> 826 scores = self.decoder(data_point_tensor)
    827 scores = self._mask_scores(scores, data_points)
    829 # if anything could possibly be predicted

File ~/miniforge3/envs/flair/lib/python3.9/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniforge3/envs/flair/lib/python3.9/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x4 and 768x650)
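
The shapes in the error message suggest that the decoder expects 768 features per data point but receives only 4. The issue does not state the root cause, so the following is only a dependency-free sketch of one way such a shape could arise: averaging a span's token embeddings across the embedding dimension instead of across the tokens. The span length of 4 and the helper names `mean_over_tokens` and `mean_over_features` are assumptions made for illustration and are not part of Flair.

```python
from statistics import fmean

EMBEDDING_DIM = 768  # decoder input size, taken from the "768x650" in the error
NUM_TOKENS = 4       # hypothetical span length, chosen to match the "1x4"

# Stand-in for the token embeddings of one span:
# NUM_TOKENS rows, each holding EMBEDDING_DIM floats.
token_embeddings = [[float(t)] * EMBEDDING_DIM for t in range(NUM_TOKENS)]


def mean_over_tokens(rows):
    """Average across tokens: yields one value per embedding dimension."""
    return [fmean(column) for column in zip(*rows)]


def mean_over_features(rows):
    """Average across embedding dimensions: yields one value per token."""
    return [fmean(row) for row in rows]


pooled_ok = mean_over_tokens(token_embeddings)     # length matches the decoder input
pooled_bad = mean_over_features(token_embeddings)  # length equals the span length

print(len(pooled_ok))   # 768
print(len(pooled_bad))  # 4
```

If the pooled vector had length 4 as in `pooled_bad`, a linear layer with 768 input features would reject it with a shape mismatch like the one above. Whether this is what actually happens inside the Entity Linker is not confirmed by the issue.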

Screenshots

No response

Additional Context

No response

Environment

Versions:

Flair: 0.11.3
Pytorch: 1.13.0
Transformers: 4.24.0
GPU: False

@aynetdia aynetdia added the bug Something isn't working label Feb 23, 2023
@aynetdia aynetdia self-assigned this Feb 23, 2023
alanakbik added a commit that referenced this issue Mar 3, 2023