
Commit

fix extern dataset update
Benedikt Fuchs committed Jul 31, 2023
1 parent 270ef05 commit c492abf
Showing 2 changed files with 114 additions and 113 deletions.
225 changes: 113 additions & 112 deletions flair/embeddings/transformer.py
@@ -70,11 +70,11 @@ def truncate_hidden_states(hidden_states: torch.Tensor, input_ids: torch.Tensor)

@torch.jit.script_if_tracing
def combine_strided_tensors(
hidden_states: torch.Tensor,
overflow_to_sample_mapping: torch.Tensor,
half_stride: int,
max_length: int,
default_value: int,
hidden_states: torch.Tensor,
overflow_to_sample_mapping: torch.Tensor,
half_stride: int,
max_length: int,
default_value: int,
) -> torch.Tensor:
_, counts = torch.unique(overflow_to_sample_mapping, sorted=True, return_counts=True)
sentence_count = int(overflow_to_sample_mapping.max().item() + 1)
@@ -94,9 +94,9 @@ def combine_strided_tensors(
selected_sentences = hidden_states[overflow_to_sample_mapping == sentence_id]
if selected_sentences.size(0) > 1:
start_part = selected_sentences[0, : half_stride + 1]
mid_part = selected_sentences[:, half_stride + 1: max_length - 1 - half_stride]
mid_part = selected_sentences[:, half_stride + 1 : max_length - 1 - half_stride]
mid_part = torch.reshape(mid_part, (mid_part.shape[0] * mid_part.shape[1],) + mid_part.shape[2:])
end_part = selected_sentences[selected_sentences.shape[0] - 1, max_length - half_stride - 1:]
end_part = selected_sentences[selected_sentences.shape[0] - 1, max_length - half_stride - 1 :]
sentence_hidden_state = torch.cat((start_part, mid_part, end_part), dim=0)
sentence_hidden_states[sentence_id, : sentence_hidden_state.shape[0]] = torch.cat(
(start_part, mid_part, end_part), dim=0
@@ -109,11 +109,11 @@

@torch.jit.script_if_tracing
def fill_masked_elements(
all_token_embeddings: torch.Tensor,
sentence_hidden_states: torch.Tensor,
mask: torch.Tensor,
word_ids: torch.Tensor,
lengths: torch.LongTensor,
all_token_embeddings: torch.Tensor,
sentence_hidden_states: torch.Tensor,
mask: torch.Tensor,
word_ids: torch.Tensor,
lengths: torch.LongTensor,
):
for i in torch.arange(int(all_token_embeddings.shape[0])):
r = insert_missing_embeddings(sentence_hidden_states[i][mask[i] & (word_ids[i] >= 0)], word_ids[i], lengths[i])
@@ -123,7 +123,7 @@

@torch.jit.script_if_tracing
def insert_missing_embeddings(
token_embeddings: torch.Tensor, word_id: torch.Tensor, length: torch.LongTensor
token_embeddings: torch.Tensor, word_id: torch.Tensor, length: torch.LongTensor
) -> torch.Tensor:
# in some cases we need to insert zero vectors for tokens without embedding.
if token_embeddings.shape[0] == 0:
@@ -166,10 +166,10 @@ def insert_missing_embeddings(

@torch.jit.script_if_tracing
def fill_mean_token_embeddings(
all_token_embeddings: torch.Tensor,
sentence_hidden_states: torch.Tensor,
word_ids: torch.Tensor,
token_lengths: torch.Tensor,
all_token_embeddings: torch.Tensor,
sentence_hidden_states: torch.Tensor,
word_ids: torch.Tensor,
token_lengths: torch.Tensor,
):
for i in torch.arange(all_token_embeddings.shape[0]):
for _id in torch.arange(token_lengths[i]): # type: ignore[call-overload]
@@ -196,7 +196,7 @@ def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths:


def _legacy_reconstruct_word_ids(
embedding: "TransformerBaseEmbeddings", flair_tokens: List[List[str]]
embedding: "TransformerBaseEmbeddings", flair_tokens: List[List[str]]
) -> List[List[Optional[int]]]:
word_ids_list = []
max_len = 0
@@ -307,25 +307,25 @@ class TransformerBaseEmbeddings(Embeddings[Sentence]):
"""

def __init__(
self,
name: str,
tokenizer: PreTrainedTokenizer,
embedding_length: int,
context_length: int,
context_dropout: float,
respect_document_boundaries: bool,
stride: int,
allow_long_sentences: bool,
fine_tune: bool,
truncate: bool,
use_lang_emb: bool,
is_document_embedding: bool = False,
is_token_embedding: bool = False,
force_device: Optional[torch.device] = None,
force_max_length: bool = False,
feature_extractor: Optional[FeatureExtractionMixin] = None,
needs_manual_ocr: Optional[bool] = None,
use_context_separator: bool = True,
self,
name: str,
tokenizer: PreTrainedTokenizer,
embedding_length: int,
context_length: int,
context_dropout: float,
respect_document_boundaries: bool,
stride: int,
allow_long_sentences: bool,
fine_tune: bool,
truncate: bool,
use_lang_emb: bool,
is_document_embedding: bool = False,
is_token_embedding: bool = False,
force_device: Optional[torch.device] = None,
force_max_length: bool = False,
feature_extractor: Optional[FeatureExtractionMixin] = None,
needs_manual_ocr: Optional[bool] = None,
use_context_separator: bool = True,
) -> None:
self.name = name
super().__init__()
@@ -473,32 +473,32 @@ def prepare_tensors(self, sentences: List[Sentence], device: Optional[torch.devi

# random check some tokens to save performance.
if (self.needs_manual_ocr or self.tokenizer_needs_ocr_boxes) and not all(
[
flair_tokens[0][0].has_metadata("bbox"),
flair_tokens[0][-1].has_metadata("bbox"),
flair_tokens[-1][0].has_metadata("bbox"),
flair_tokens[-1][-1].has_metadata("bbox"),
]
[
flair_tokens[0][0].has_metadata("bbox"),
flair_tokens[0][-1].has_metadata("bbox"),
flair_tokens[-1][0].has_metadata("bbox"),
flair_tokens[-1][-1].has_metadata("bbox"),
]
):
raise ValueError(f"The embedding '{self.name}' requires the ocr 'bbox' set as metadata on all tokens.")

if self.feature_extractor is not None and not all(
[
sentences[0].has_metadata("image"),
sentences[-1].has_metadata("image"),
]
[
sentences[0].has_metadata("image"),
sentences[-1].has_metadata("image"),
]
):
raise ValueError(f"The embedding '{self.name}' requires the 'image' set as metadata for all sentences.")

return self.__build_transformer_model_inputs(sentences, offsets, lengths, flair_tokens, device)

def __build_transformer_model_inputs(
self,
sentences: List[Sentence],
offsets: List[int],
sentence_lengths: List[int],
flair_tokens: List[List[Token]],
device: torch.device,
self,
sentences: List[Sentence],
offsets: List[int],
sentence_lengths: List[int],
flair_tokens: List[List[Token]],
device: torch.device,
):
tokenizer_kwargs: Dict[str, Any] = {}
if self.tokenizer_needs_ocr_boxes:
@@ -559,7 +559,7 @@ def __build_transformer_model_inputs(
sentence_idx = 0
for sentence, part_length in zip(sentences, sentence_part_lengths):
lang_id = lang2id.get(sentence.get_language_code(), 0)
model_kwargs["langs"][sentence_idx: sentence_idx + part_length] = lang_id
model_kwargs["langs"][sentence_idx : sentence_idx + part_length] = lang_id
sentence_idx += part_length

if "bbox" in batch_encoding:
@@ -801,12 +801,12 @@ def collect_dynamic_axes(cls, embedding: "TransformerEmbeddings", tensors):

@classmethod
def export_from_embedding(
cls,
path: Union[str, Path],
embedding: "TransformerEmbeddings",
example_sentences: List[Sentence],
opset_version: int = 14,
providers: Optional[List] = None,
cls,
path: Union[str, Path],
embedding: "TransformerEmbeddings",
example_sentences: List[Sentence],
opset_version: int = 14,
providers: Optional[List] = None,
):
path = str(path)
example_tensors = embedding.prepare_tensors(example_sentences)
@@ -899,7 +899,7 @@ def create_from_embedding(cls, module: ScriptModule, embedding: "TransformerEmbe

@classmethod
def parameter_to_list(
cls, embedding: "TransformerEmbeddings", wrapper: torch.nn.Module, sentences: List[Sentence]
cls, embedding: "TransformerEmbeddings", wrapper: torch.nn.Module, sentences: List[Sentence]
) -> Tuple[List[str], List[torch.Tensor]]:
tensors = embedding.prepare_tensors(sentences)
param_names = list(inspect.signature(wrapper.forward).parameters.keys())
@@ -912,35 +912,35 @@ def parameter_to_list(
@register_embeddings
class TransformerJitWordEmbeddings(TokenEmbeddings, TransformerJitEmbeddings):
def __init__(
self,
**kwargs,
self,
**kwargs,
) -> None:
TransformerJitEmbeddings.__init__(self, **kwargs)


@register_embeddings
class TransformerJitDocumentEmbeddings(DocumentEmbeddings, TransformerJitEmbeddings):
def __init__(
self,
**kwargs,
self,
**kwargs,
) -> None:
TransformerJitEmbeddings.__init__(self, **kwargs)


@register_embeddings
class TransformerOnnxWordEmbeddings(TokenEmbeddings, TransformerOnnxEmbeddings):
def __init__(
self,
**kwargs,
self,
**kwargs,
) -> None:
TransformerOnnxEmbeddings.__init__(self, **kwargs)


@register_embeddings
class TransformerOnnxDocumentEmbeddings(DocumentEmbeddings, TransformerOnnxEmbeddings):
def __init__(
self,
**kwargs,
self,
**kwargs,
) -> None:
TransformerOnnxEmbeddings.__init__(self, **kwargs)

@@ -950,27 +950,27 @@ class TransformerEmbeddings(TransformerBaseEmbeddings):
onnx_cls: Type[TransformerOnnxEmbeddings] = TransformerOnnxEmbeddings

def __init__(
self,
model: str = "bert-base-uncased",
fine_tune: bool = True,
layers: str = "-1",
layer_mean: bool = True,
subtoken_pooling: str = "first",
cls_pooling: str = "cls",
is_token_embedding: bool = True,
is_document_embedding: bool = True,
allow_long_sentences: bool = False,
use_context: Union[bool, int] = False,
respect_document_boundaries: bool = True,
context_dropout: float = 0.5,
saved_config: Optional[PretrainedConfig] = None,
tokenizer_data: Optional[BytesIO] = None,
feature_extractor_data: Optional[BytesIO] = None,
name: Optional[str] = None,
force_max_length: bool = False,
needs_manual_ocr: Optional[bool] = None,
use_context_separator: bool = True,
**kwargs,
self,
model: str = "bert-base-uncased",
fine_tune: bool = True,
layers: str = "-1",
layer_mean: bool = True,
subtoken_pooling: str = "first",
cls_pooling: str = "cls",
is_token_embedding: bool = True,
is_document_embedding: bool = True,
allow_long_sentences: bool = False,
use_context: Union[bool, int] = False,
respect_document_boundaries: bool = True,
context_dropout: float = 0.5,
saved_config: Optional[PretrainedConfig] = None,
tokenizer_data: Optional[BytesIO] = None,
feature_extractor_data: Optional[BytesIO] = None,
name: Optional[str] = None,
force_max_length: bool = False,
needs_manual_ocr: Optional[bool] = None,
use_context_separator: bool = True,
**kwargs,
) -> None:
self.instance_parameters = self.get_instance_parameters(locals=locals())
del self.instance_parameters["saved_config"]
@@ -1107,14 +1107,15 @@ def embedding_length(self) -> int:

return self.embedding_length_internal


def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
if transformers.__version__ >= Version(4, 31, 0):
assert isinstance(state_dict, dict)
state_dict.pop(f"{prefix}model.embeddings.position_ids", None)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)

def _has_initial_cls_token(self) -> bool:
# most models have CLS token as last token (GPT-1, GPT-2, TransfoXL, XLNet, XLM), but BERT is initial
@@ -1248,23 +1249,23 @@ def to_params(self):
def _can_document_embedding_shortcut(self):
# cls first pooling can be done without recreating sentence hidden states
return (
self.document_embedding
and not self.token_embedding
and self.cls_pooling == "cls"
and self.initial_cls_token
self.document_embedding
and not self.token_embedding
and self.cls_pooling == "cls"
and self.initial_cls_token
)

def forward(
self,
input_ids: torch.Tensor,
sub_token_lengths: Optional[torch.LongTensor] = None,
token_lengths: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
overflow_to_sample_mapping: Optional[torch.Tensor] = None,
word_ids: Optional[torch.Tensor] = None,
langs: Optional[torch.Tensor] = None,
bbox: Optional[torch.Tensor] = None,
pixel_values: Optional[torch.Tensor] = None,
self,
input_ids: torch.Tensor,
sub_token_lengths: Optional[torch.LongTensor] = None,
token_lengths: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
overflow_to_sample_mapping: Optional[torch.Tensor] = None,
word_ids: Optional[torch.Tensor] = None,
langs: Optional[torch.Tensor] = None,
bbox: Optional[torch.Tensor] = None,
pixel_values: Optional[torch.Tensor] = None,
):
model_kwargs = {}
if langs is not None:
@@ -1353,8 +1354,8 @@ def forward(
word_ids,
token_lengths,
)
all_token_embeddings[:, :, sentence_hidden_states.shape[2]:] = fill_masked_elements(
all_token_embeddings[:, :, sentence_hidden_states.shape[2]:],
all_token_embeddings[:, :, sentence_hidden_states.shape[2] :] = fill_masked_elements(
all_token_embeddings[:, :, sentence_hidden_states.shape[2] :],
sentence_hidden_states,
last_mask,
word_ids,
@@ -1374,7 +1375,7 @@ def _forward_tensors(self, tensors) -> Dict[str, torch.Tensor]:
return self.forward(**tensors)

def export_onnx(
self, path: Union[str, Path], example_sentences: List[Sentence], **kwargs
self, path: Union[str, Path], example_sentences: List[Sentence], **kwargs
) -> TransformerOnnxEmbeddings:
"""Export TransformerEmbeddings to OnnxFormat.
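The `_load_from_state_dict` override shown above drops the `model.embeddings.position_ids` entry from older checkpoints before handing off to the parent loader, since transformers releases from 4.31.0 onwards no longer store that buffer in the state dict. Below is a minimal, self-contained sketch of the same pattern; the module and key names are illustrative stand-ins, not flair's actual classes.

```python
import torch


class CompatModule(torch.nn.Module):
    """Hypothetical module showing how a stale buffer key can be dropped from an
    old checkpoint before torch's regular state-dict loading runs."""

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # Checkpoints written against an older dependency may still contain this
        # buffer; removing it keeps strict loading from flagging it as unexpected.
        state_dict.pop(f"{prefix}model.embeddings.position_ids", None)
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )
```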
2 changes: 1 addition & 1 deletion tests/test_datasets.py
@@ -748,7 +748,7 @@ def test_masakhane_corpus(tasks_base_path):
"bam": {"train": 4462, "dev": 638, "test": 1274},
"bbj": {"train": 3384, "dev": 483, "test": 966},
"ewe": {"train": 3505, "dev": 501, "test": 1001},
"fon": {"train": 4343, "dev": 621, "test": 1240},
"fon": {"train": 4343, "dev": 623, "test": 1228},
"hau": {"train": 5716, "dev": 816, "test": 1633},
"ibo": {"train": 7634, "dev": 1090, "test": 2181},
"kin": {"train": 7825, "dev": 1118, "test": 2235},
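The test change above updates the expected split sizes for the Fon ("fon") portion of the Masakhane corpus test after the external dataset was updated. As a hedged sketch of how such expectations are typically checked against a loaded flair Corpus (the helper function is hypothetical; flair Corpus objects do expose train, dev and test splits):

```python
from typing import Dict


def check_split_sizes(corpus, expected: Dict[str, int]) -> None:
    """Compare the lengths of a corpus' splits against expected counts.

    `corpus` is assumed to be a flair Corpus-like object with .train/.dev/.test.
    """
    actual = {"train": len(corpus.train), "dev": len(corpus.dev), "test": len(corpus.test)}
    assert actual == expected, f"split sizes differ: {actual} != {expected}"


# e.g. check_split_sizes(fon_corpus, {"train": 4343, "dev": 623, "test": 1228})
```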
