Skip to content

Commit

Permalink
Merge pull request #3058 from flairNLP/fix-tart-prediction
Browse files Browse the repository at this point in the history
Change indexing in TARSTagger predict
  • Loading branch information
alanakbik authored Jan 19, 2023
2 parents ef926e6 + 96f568e commit 1577601
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
9 changes: 6 additions & 3 deletions flair/models/tars_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ def predict(
for tuple in sorted_x:
# get the span and its label
label = tuple[0]
# label = span.get_labels("tars_temp_label")[0].value

label_length = (
0 if not self.prefix else len(label.value.split(" ")) + len(self.separator.split(" "))
)
Expand All @@ -560,18 +560,21 @@ def predict(
if corresponding_token is None:
tag_this = False
continue
if token.idx in already_set_indices:
if corresponding_token.idx in already_set_indices:
tag_this = False
continue

# only add if all tokens have no label
if tag_this:
already_set_indices.extend(token.idx for token in label.data_point)
# make and add a corresponding predicted span
predicted_span = Span(
[sentence.get_token(token.idx - label_length) for token in label.data_point]
)
predicted_span.add_label(label_name, value=label.value, score=label.score)

# set indices so that no token can be tagged twice
already_set_indices.extend(token.idx for token in predicted_span)

# clearing token embeddings to save memory
store_embeddings(batch, storage_mode=embedding_storage_mode)

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ filterwarnings = [
'ignore:bilinear is deprecated and will be removed in Pillow 10', # huggingface layoutlmv2 has deprecated calls.
'ignore:nearest is deprecated and will be removed in Pillow 10', # huggingface layoutlmv2 has deprecated calls.
'ignore:The `device` argument is deprecated and will be removed in v5 of Transformers.', # hf layoutlmv3 calls deprecated hf.
"ignore:the imp module is deprecated:DeprecationWarning:past" # ignore DeprecationWarning from hyperopt dependency
"ignore:the imp module is deprecated:DeprecationWarning:past", # ignore DeprecationWarning from hyperopt dependency
"ignore:.*imp module.*:DeprecationWarning", # ignore DeprecationWarnings that involve imp module
]
markers = [
"integration",
Expand Down

0 comments on commit 1577601

Please sign in to comment.