Scriptable Tokenizer for Text Classification Example #1691

Merged
merged 10 commits on Jul 12, 2022
8 changes: 7 additions & 1 deletion examples/README.md
@@ -4,6 +4,7 @@
* [Serving torchvision image classification models](#serving-image-classification-models)
* [Serving custom model with custom service handler](#serving-custom-model-with-custom-service-handler)
* [Serving text classification model](#serving-text-classification-model)
* [Serving text classification model with scriptable tokenizer](#serving-text-classification-model-with-scriptable-tokenizer)
* [Serving object detection model](#serving-object-detection-model)
* [Serving image segmentation model](#serving-image-segmentation-model)
* [Serving huggingface transformers model](#serving-huggingface-transformers)
@@ -48,7 +49,7 @@ Following are the steps to create a torch-model-archive (.mar) to execute an eag

```bash
torch-model-archiver --model-name <model_name> --version <model_version_number> --serialized-file <path_to_executable_script_module> --extra-files <path_to_index_to_name_json_file> --handler <path_to_custom_handler_or_default_handler_name>
```

## Serving image classification models
The following example demonstrates how to create an image classifier model archive, serve it on TorchServe, and run image prediction using TorchServe's default image_classifier handler:
@@ -67,6 +68,11 @@ The following example demonstrates how to create and serve a custom text_classif

* [Text classification example](text_classification)

## Serving text classification model with scriptable tokenizer

This example shows how to combine a text classification model with a scriptable tokenizer into a single, scripted artifact to serve with TorchServe. A scriptable tokenizer is a tokenizer compatible with TorchScript.
* [Text classification example with scriptable tokenizer](text_classification_with_scriptable_tokenizer)

## Serving object detection model

The following example demonstrates how to create and serve a pretrained fast-rcnn model with the default object_detector handler provided by TorchServe:
66 changes: 66 additions & 0 deletions examples/text_classification_with_scriptable_tokenizer/README.md
@@ -0,0 +1,66 @@
# Text Classification using a Scriptable Tokenizer

TorchScript is a way to serialize and optimize your PyTorch models.
A scriptable tokenizer is a special tokenizer which is compatible with [TorchScript's compiler](https://pytorch.org/docs/stable/jit.html) so that it can be jointly serialized with a PyTorch model.
When deploying an NLP model it is important to use the same tokenizer during training and inference to achieve the same model accuracy in both phases of the model life cycle.
Using a different tokenizer for inference than the one used during training can degrade model performance significantly.
Thus, it can be beneficial to combine the tokenizer and the model into a single deployment artifact, as this reduces the amount of preprocessing code in the handler and the synchronization effort between the training and inference code bases.
This example shows how to combine a text classification model with a scriptable tokenizer into a single artifact and deploy it with TorchServe.
For demonstration purposes we use a pretrained model as created in this tutorial:

https://github.com/pytorch/text/blob/main/examples/tutorials/sst2_classification_non_distributed.py


# Training the Model

To train the model we need to follow the steps described in [this tutorial](https://github.com/pytorch/text/blob/main/examples/tutorials/sst2_classification_non_distributed.py) and export the model weights into a ```model.pt``` file.
To use the SST-2 dataset, torchtext requires torchdata to be installed.
This can be achieved with pip by running:

```
pip install torchdata
```

Or with conda by running:

```
conda install -c pytorch torchdata
```

Subsequently, we need to add the command ```torch.save(model.state_dict(), "model.pt")``` at the end of the training script and then run it with:

```bash
python sst2_classification_non_distributed.py
```
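
For reference, the sketch below illustrates what that export amounts to. It is only an illustration, not part of the tutorial: it builds the same classifier architecture used later in this example and saves a state dict to ```model.pt```; in the real training script the ```torch.save``` call goes after the training loop so that the fine-tuned weights are exported.

```python
import torch
from torchtext.models import XLMR_BASE_ENCODER, RobertaClassificationHead

# Same classifier architecture as in the tutorial: XLM-R base encoder + binary classification head.
classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
model = XLMR_BASE_ENCODER.get_model(head=classifier_head)

# In the actual training script this line comes after the training loop, so the
# exported state dict contains the fine-tuned weights rather than the pretrained ones.
torch.save(model.state_dict(), "model.pt")
```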

A pretrained ```model.pt``` is also available for download [here](https://bert-mar-file.s3.us-west-2.amazonaws.com/text_classification_with_scriptable_tokenizer/model.pt).
The trained model can then be combined with the tokenizer and compiled with TorchScript using the ```script_tokenizer_and_model.py``` script. Here ```model.pt``` contains the model weights saved after training and ```model_jit.pt``` is the combined tokenizer and model compiled with TorchScript.

```bash
python script_tokenizer_and_model.py model.pt model_jit.pt
```


# Serve the Text Classification Model on TorchServe

* Create a torch model archive using the torch-model-archiver utility to archive the file created above.

```bash
torch-model-archiver --model-name scriptable_tokenizer --version 1.0 --serialized-file model_jit.pt --handler handler.py --extra-files "index_to_name.json"
```

* Register the model on TorchServe using the above model archive file and run a classification

```bash
mkdir model_store
mv scriptable_tokenizer.mar model_store/
torchserve --start --model-store model_store --models my_tc=scriptable_tokenizer.mar
curl http://127.0.0.1:8080/predictions/my_tc -T sample_text.txt
```
* Expected Output:
```
{
"Negative": 0.0972590446472168,
"Positive": 0.9027408957481384
}
```
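
As an alternative to ```curl```, the same request can be sent from Python. This is just a sketch using the ```requests``` library (not used elsewhere in this example) and assumes TorchServe is running locally with the model registered as ```my_tc``` on the default inference port 8080:

```python
import requests

# Post the sample text to the TorchServe inference API and print the class probabilities.
with open("sample_text.txt", "rb") as f:
    response = requests.post("http://127.0.0.1:8080/predictions/my_tc", data=f)

print(response.json())  # e.g. {"Negative": 0.097..., "Positive": 0.902...}
```
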
113 changes: 113 additions & 0 deletions examples/text_classification_with_scriptable_tokenizer/handler.py
@@ -0,0 +1,113 @@
"""
Module for text classification with scriptable tokenizer
DOES NOT SUPPORT BATCH!
"""
import logging
from abc import ABC

import torch
import torch.nn.functional as F

# Necessary to successfully load the model (see https://github.com/pytorch/text/issues/1793)
import torchtext # nopycln: import

from ts.torch_handler.base_handler import BaseHandler
from ts.utils.util import CLEANUP_REGEX, map_class_to_label

logger = logging.getLogger(__name__)


def remove_html_tags(text):
"""
Removes html tags
"""
clean_text = CLEANUP_REGEX.sub("", text)
return clean_text


class CustomTextClassifier(BaseHandler, ABC):
"""
    TextClassifier handler class. This handler takes a text (string) as
    input and returns the class probabilities mapped to their labels.
Because the predefined TextHandler in ts/torch_handler defines unnecessary
steps like loading a vocabulary file for the tokenizer, we define our handler
starting from BaseHandler.
"""

def preprocess(self, data):
"""
Tokenization is dealt with inside the scripted model itself.
        We therefore only apply these basic cleanup operations:
- remove html tags
- lowercase all text

Args:
data (str): The input data is in the form of a string

Returns:
            (str): The cleaned text, ready to be passed to the scripted model
                (which performs the tokenization itself)
"""

# Compat layer: normally the envelope should just return the data
# directly, but older versions of Torchserve didn't have envelope.
# Processing only the first input, not handling batch inference

line = data[0]
text = line.get("data") or line.get("body")
# Decode text if not a str but bytes or bytearray
if isinstance(text, (bytes, bytearray)):
text = text.decode("utf-8")

text = remove_html_tags(text)
text = text.lower()

return text

def inference(self, data, *args, **kwargs):
"""The Inference Request is made through this function and the user
needs to override the inference function to customize it.

Args:
            data (str): The preprocessed text. Tokenization happens inside the
                scripted model, so no tensor conversion is needed here.

Returns:
(Torch Tensor): The predicted response from the model is returned
in this function.
"""
with torch.no_grad():
results = self.model(data)
return results

def postprocess(self, data):
"""
        The post process function converts the prediction response into a
        TorchServe compatible format.

Args:
            data (Torch Tensor): The prediction output of the model

Returns:
            (list): A list of dictionaries mapping each class label to its
                predicted probability.
"""
        # Make the class dimension explicit (avoids the implicit-dim softmax deprecation warning)
        data = F.softmax(data, dim=1)
data = data.tolist()
return map_class_to_label(data, self.mapping)

def _load_torchscript_model(self, model_pt_path):
"""Loads the PyTorch model and returns the NN model object.

Args:
model_pt_path (str): denotes the path of the model file.

Returns:
            (NN Model Object): The loaded TorchScript model object.
"""
# TODO: remove this method if https://github.com/pytorch/text/issues/1793 gets resolved
model = torch.jit.load(model_pt_path)
model.to(self.device)
return model
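
For orientation, here is a rough sketch of exercising the handler flow outside of TorchServe. It is not part of the example: in production TorchServe drives the handler and supplies the request data and context, so the attribute assignments below merely stand in for what the handler's `initialize()` normally does, and the sample sentence is made up for illustration.

```python
import torch

from handler import CustomTextClassifier  # assumes handler.py is importable

handler = CustomTextClassifier()
# Normally set up during initialize(); assigned by hand here for illustration.
handler.model = torch.jit.load("model_jit.pt")
handler.device = torch.device("cpu")
handler.mapping = {"0": "Negative", "1": "Positive"}

# A single (non-batched) request, shaped the way TorchServe passes it to preprocess().
request = [{"data": b"The movie was <b>great</b>!"}]
text = handler.preprocess(request)
logits = handler.inference(text)
print(handler.postprocess(logits))  # e.g. [{"Negative": 0.1, "Positive": 0.9}]
```
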
4 changes: 4 additions & 0 deletions examples/text_classification_with_scriptable_tokenizer/index_to_name.json
@@ -0,0 +1,4 @@
{
"0":"Negative",
"1":"Positive"
}
1 change: 1 addition & 0 deletions examples/text_classification_with_scriptable_tokenizer/sample_text.txt
@@ -0,0 +1 @@
The Rock is destined to be the 21st Century 's new `` Conan '' and that he 's going to make a splash even greater than Arnold Schwarzenegger , Jean-Claud Van Damme or Steven Segal .
104 changes: 104 additions & 0 deletions examples/text_classification_with_scriptable_tokenizer/script_tokenizer_and_model.py
@@ -0,0 +1,104 @@
"""
Combine tokenizer and XLM-RoBERTa model pretrained on SST-2 Binary text classification
"""

import argparse
from typing import Any

import torch
import torchtext.functional as F
import torchtext.transforms as T
from torch import nn
from torch.hub import load_state_dict_from_url
from torchtext.models import XLMR_BASE_ENCODER, RobertaClassificationHead

PADDING_IDX = 1
BOS_IDX = 0
EOS_IDX = 2
MAX_SEQ_LEN = 256
# Vocab file for the pretrained XLM-RoBERTa model
XLMR_VOCAB_PATH = r"https://download.pytorch.org/models/text/xlmr.vocab.pt"
# Model file for the pretrained SentencePiece tokenizer
XLMR_SPM_MODEL_PATH = (
r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model"
)


class TokenizerModelAdapter(nn.Module):
"""
TokenizerModelAdapter moves input onto device and adds batch dimension
"""

def __init__(self, padding_idx):
super().__init__()
self._padding_idx = padding_idx
self._dummy_param = nn.Parameter(torch.empty(0))

def forward(self, tokens: Any) -> torch.Tensor:
"""
Moves input onto device and adds batch dimension.

Args:
            tokens (Any): Tokenizer output. Because we script the combined model, we need to
                type-hint the input argument of the adapter module, which TorchScript
                infers as Any. Choosing a more restrictive type makes the scripting fail.

Returns:
(Tensor): On device text tensor with batch dimension
"""
tokens = F.to_tensor(tokens, padding_value=self._padding_idx).to(
self._dummy_param.device
)
# If a single sample is tokenized we need to add the batch dimension
if len(tokens.shape) < 2:
return tokens.unsqueeze(0)
return tokens


def main(args):

# Chain preprocessing steps as defined during training.
text_transform = T.Sequential(
T.SentencePieceTokenizer(XLMR_SPM_MODEL_PATH),
T.VocabTransform(load_state_dict_from_url(XLMR_VOCAB_PATH)),
T.Truncate(MAX_SEQ_LEN - 2),
T.AddToken(token=BOS_IDX, begin=True),
T.AddToken(token=EOS_IDX, begin=False),
)

NUM_CLASSES = 2
INPUT_DIM = 768

classifier_head = RobertaClassificationHead(
num_classes=NUM_CLASSES, input_dim=INPUT_DIM
)

model = XLMR_BASE_ENCODER.get_model(head=classifier_head)

    # Load the trained parameters into the model
model.load_state_dict(torch.load(args.input_file))

# Chain the tokenizer, the adapter and the model
combi_model = T.Sequential(
text_transform,
TokenizerModelAdapter(PADDING_IDX),
model,
)

combi_model.eval()

# Make sure to move the model to CPU to avoid placement error during loading
combi_model.to("cpu")

combi_model_jit = torch.jit.script(combi_model)

torch.jit.save(combi_model_jit, args.output_file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Combine tokenizer and model.")
parser.add_argument("input_file", type=str)
parser.add_argument("output_file", type=str)

args = parser.parse_args()
main(args)
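
A quick, optional sanity check (not part of the PR) is to load the scripted artifact and run it on a raw string, confirming that tokenization and classification happen end to end inside the single TorchScript module; the example sentence is arbitrary and `model_jit.pt` is assumed to be in the working directory.

```python
import torch
import torch.nn.functional as F

# Load the combined tokenizer + model artifact produced by the script above.
combi_model = torch.jit.load("model_jit.pt")
combi_model.eval()

with torch.no_grad():
    logits = combi_model("a wonderfully warm and funny little film")
    probs = F.softmax(logits, dim=1)

print(probs)  # shape (1, 2): probabilities for "Negative" and "Positive"
```
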
36 changes: 23 additions & 13 deletions model-archiver/model_archiver/model_packaging.py
@@ -1,14 +1,14 @@

"""
Command line interface to export model files to be used for inference by MXNet Model Server
"""

import logging
import sys
import shutil
from .arg_parser import ArgParser
from .model_packaging_utils import ModelExportUtils
from .model_archiver_error import ModelArchiverError
import sys

from model_archiver.arg_parser import ArgParser
from model_archiver.model_archiver_error import ModelArchiverError
from model_archiver.model_packaging_utils import ModelExportUtils


def package_model(args, manifest):
@@ -26,19 +26,29 @@ def package_model(args, manifest):
try:
ModelExportUtils.validate_inputs(model_name, export_file_path)
# Step 1 : Check if .mar already exists with the given model name
export_file_path = ModelExportUtils.check_mar_already_exists(model_name, export_file_path,
args.force, args.archive_format)
export_file_path = ModelExportUtils.check_mar_already_exists(
model_name, export_file_path, args.force, args.archive_format
)

# Step 2 : Copy all artifacts to temp directory
artifact_files = {'model_file': model_file, 'serialized_file': serialized_file, 'handler': handler,
'extra_files': extra_files, 'requirements-file': requirements_file}
artifact_files = {
"model_file": model_file,
"serialized_file": serialized_file,
"handler": handler,
"extra_files": extra_files,
"requirements-file": requirements_file,
}

model_path = ModelExportUtils.copy_artifacts(model_name, **artifact_files)

# Step 2 : Zip 'em all up
ModelExportUtils.archive(export_file_path, model_name, model_path, manifest, args.archive_format)
ModelExportUtils.archive(
export_file_path, model_name, model_path, manifest, args.archive_format
)
shutil.rmtree(model_path)
logging.info("Successfully exported model %s to file %s", model_name, export_file_path)
logging.info(
"Successfully exported model %s to file %s", model_name, export_file_path
)
except ModelArchiverError as e:
logging.error(e)
sys.exit(1)
@@ -50,11 +60,11 @@ def generate_model_archive():
:return:
"""

logging.basicConfig(format='%(levelname)s - %(message)s')
logging.basicConfig(format="%(levelname)s - %(message)s")
args = ArgParser.export_model_args_parser().parse_args()
manifest = ModelExportUtils.generate_manifest_json(args)
package_model(args, manifest=manifest)


if __name__ == '__main__':
if __name__ == "__main__":
generate_model_archive()