Commit

fix tok length issue
markus583 committed Jul 7, 2024
1 parent 8227d86 commit cfd5e24
Showing 6 changed files with 33 additions and 28 deletions.
1 change: 0 additions & 1 deletion requirements.txt
@@ -13,7 +13,6 @@ numpy==1.23.5
 pydantic
 torchinfo
 conllu
-genalog
 pandarallel
 cohere
 replicate
2 changes: 1 addition & 1 deletion scripts/export_to_onnx_charbert.py
@@ -7,7 +7,7 @@
 from transformers import AutoModelForTokenClassification, HfArgumentParser
 
-import wtpsplit  # noqa
+import wtpsplit.models  # noqa
 
 @dataclass
 class Args:
38 changes: 20 additions & 18 deletions scripts/export_to_onnx_sat.py
@@ -12,9 +12,9 @@
 
 @dataclass
 class Args:
-    model_name_or_path: str = "segment-any-text/sat-12l-no-limited-lookahead"
-    output_dir: str = "sat-12l-no-limited-lookahead"
-    device: str = "cpu"
+    model_name_or_path: str = "segment-any-text/sat-1l-sm"
+    output_dir: str = "sat-1l-sm"
+    device: str = "cuda"
     # TODO: lora merging here
 
 
@@ -24,15 +24,15 @@ class Args:
     output_dir = Path(args.output_dir)
     output_dir.mkdir(exist_ok=True, parents=True)
 
-    model = AutoModelForTokenClassification.from_pretrained(args.model_name_or_path, force_download=True)
-    # model = model.half()  # CUDA ONLY!
+    model = AutoModelForTokenClassification.from_pretrained(args.model_name_or_path, force_download=False)
+    model = model.half()  # CUDA ONLY!
     model = model.to(args.device)
 
     torch.onnx.export(
         model,
         {
-            "attention_mask": torch.zeros((1, 14), dtype=torch.long, device=args.device),
-            "input_ids": torch.zeros((1, 14), dtype=torch.long, device=args.device),
+            "attention_mask": torch.zeros((1, 1), dtype=torch.float16, device=args.device),
+            "input_ids": torch.zeros((1, 1), dtype=torch.int64, device=args.device),
         },
         output_dir / "model.onnx",
         verbose=True,
@@ -41,21 +41,23 @@ class Args:
         dynamic_axes={
             "input_ids": {0: "batch", 1: "sequence"},
             "attention_mask": {0: "batch", 1: "sequence"},
-            "logits": {0: "batch", 1: "sequence"}
+            "logits": {0: "batch", 1: "sequence"},
         },
         # opset_version=11
     )
 
-    # m = optimize_model(
-    #     str(output_dir / "model.onnx"),
-    #     model_type="bert",
-    #     optimization_options=None,
-    #     opt_level=0,
-    #     use_gpu=False,
-    # )
+    m = optimize_model(
+        str(output_dir / "model.onnx"),
+        model_type="bert",
+        num_heads=0,
+        hidden_size=0,
+        optimization_options=None,
+        opt_level=0,
+        use_gpu=False,
+    )
 
-    # optimized_model_path = output_dir / "model_optimized.onnx"
-    # onnx.save_model(m.model, optimized_model_path)
+    optimized_model_path = output_dir / "model_optimized.onnx"
+    onnx.save_model(m.model, optimized_model_path)
 
     onnx_model = onnx.load(output_dir / "model.onnx")
-    onnx.checker.check_model(onnx_model, full_check=True)
+    onnx.checker.check_model(onnx_model, full_check=True)
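
The optimization and save steps that were previously commented out are now active, so the script writes both model.onnx and model_optimized.onnx. Below is a minimal sanity check for the exported graph, assuming onnxruntime is installed and the script was run with the defaults above (output_dir "sat-1l-sm", fp16 weights); the dummy shapes and the expected output shape are illustrative, not taken from the repository:

    # Sketch only: the path and shapes assume the default export settings above.
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession("sat-1l-sm/model_optimized.onnx")
    input_ids = np.zeros((1, 8), dtype=np.int64)        # int64 ids, as exported
    attention_mask = np.ones((1, 8), dtype=np.float16)  # fp16 mask, matching the fp16 export
    (logits,) = session.run(["logits"], {"input_ids": input_ids, "attention_mask": attention_mask})
    print(logits.shape)  # expected (1, 8, num_labels), thanks to the dynamic axes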
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@
 
 setup(
     name="wtpsplit",
-    version="2.0.4",
+    version="2.0.5",
     packages=find_packages(),
     description="Universal Robust, Efficient and Adaptable Sentence Segmentation",
     author="Markus Frohmann, Igor Sterner, Benjamin Minixhofer",
2 changes: 1 addition & 1 deletion wtpsplit/__init__.py
@@ -18,7 +18,7 @@
 from wtpsplit.extract import BertCharORTWrapper, PyTorchWrapper, extract
 from wtpsplit.utils import Constants, indices_to_sentences, sigmoid, token_to_char_probs
 
-__version__ = "2.0.4"
+__version__ = "2.0.5"
 
 warnings.simplefilter("default", DeprecationWarning)  # show by default
 warnings.simplefilter("ignore", category=FutureWarning)  # for transformers
16 changes: 10 additions & 6 deletions wtpsplit/extract.py
@@ -44,7 +44,10 @@ def __getattr__(self, name):
     def __call__(self, input_ids, attention_mask):
         logits = self.ort_session.run(
             output_names=["logits"],
-            input_feed={"attention_mask": attention_mask.astype(np.int64), "input_ids": input_ids.astype(np.int64)},
+            input_feed={
+                "attention_mask": attention_mask.astype(np.int64),
+                "input_ids": input_ids.astype(np.float16),
+            },  # .astype(np.int64)},
         )[0]
 
         return {"logits": logits}
@@ -71,9 +74,9 @@ def __call__(self, attention_mask, hashed_ids=None, language_ids=None, input_ids
                 input_ids=torch.from_numpy(input_ids).to(self.model.device) if input_ids is not None else None,
                 hashed_ids=torch.from_numpy(hashed_ids).to(self.model.device) if hashed_ids is not None else None,
                 attention_mask=torch.from_numpy(attention_mask).to(self.model.device),
-                language_ids=torch.from_numpy(language_ids).to(self.model.device)
-                if language_ids is not None
-                else None,
+                language_ids=(
+                    torch.from_numpy(language_ids).to(self.model.device) if language_ids is not None else None
+                ),
             )["logits"]
             .cpu()
             .numpy()
@@ -124,8 +127,9 @@ def extract(
     text_lengths = [len(text) for text in batch_of_texts]
     # reduce block size if possible
     block_size = min(max_block_size, max(text_lengths))
-    if use_subwords and block_size == 512:
-        block_size -= 2  # account for CLS and SEP tokens
+    if use_subwords and block_size > 510:
+        overflow_length = block_size - 510
+        block_size -= overflow_length  # account for CLS and SEP tokens
 
     # make sure block_size is a multiple of downsampling rate
     downsampling_rate = getattr(model.config, "downsampling_rate", 1)
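
The last hunk is the tok-length fix named in the commit message: the old guard only fired when block_size was exactly 512, so any value above 510 other than 512 (for example 511, or anything larger when a bigger max_block_size is passed) left no room for the CLS and SEP tokens the subword tokenizer adds, overflowing the encoder's 512-token limit. A sketch of the corrected capping logic; the constant names and the helper are illustrative, not part of wtpsplit:

    # Sketch of the fix, under the assumption of a 512-token encoder limit.
    MAX_MODEL_LENGTH = 512  # hard limit of the subword encoder
    SPECIAL_TOKENS = 2      # CLS + SEP added around each block

    def cap_block_size(block_size: int, use_subwords: bool) -> int:
        limit = MAX_MODEL_LENGTH - SPECIAL_TOKENS  # 510 content tokens
        if use_subwords and block_size > limit:
            block_size -= block_size - limit  # i.e. cap at 510, as in the new code
        return block_size

    assert cap_block_size(512, True) == 510  # the only case the old check handled
    assert cap_block_size(600, True) == 510  # previously passed through unchanged
    assert cap_block_size(300, True) == 300  # short blocks are untouched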
