
Merge branch 'master' into hpu_backend
wozna committed Jun 7, 2024
2 parents 58f6a8e + a10aa45 commit b933070
Showing 9 changed files with 148 additions and 9 deletions.
2 changes: 2 additions & 0 deletions docs/performance_checklist.md
@@ -21,6 +21,8 @@ As this [example](https://colab.research.google.com/drive/1NMaLS8PG0eYhbd8IxQAaj

Start model inference optimization only after other factors, the “low-hanging fruit”, have been extensively evaluated and addressed.

- Using the `with torch.inference_mode()` context before calling the forward pass on your model, or the `@torch.inference_mode()` decorator on your `inference()` method, improves inference performance. This is achieved by [disabling](https://pytorch.org/docs/stable/generated/torch.autograd.grad_mode.inference_mode.html) view tracking and version counter bumps.

- Use fp16 for GPU inference. The speed will most likely more than double on newer GPUs with tensor cores, with negligible accuracy degradation. Technically, fp16 is a type of quantization, but since it seldom suffers from loss of accuracy for inference, it should always be explored. As shown in this [article](https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html#abstract), use of fp16 offers a speed-up in large neural network applications. A minimal sketch combining fp16 with `torch.inference_mode()` appears after this list.

- Use model quantization (e.g. int8) for CPU inference. Explore the different quantization options: dynamic quantization, static quantization, and quantization-aware training, as well as tools such as Intel Neural Compressor that provide more sophisticated quantization methods. It is worth noting that quantization comes with some loss in accuracy and might not always offer a significant speed-up on some hardware, so it might not always be the right approach.
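
As a rough sketch of the first two points (the model is a placeholder, and casting weights and inputs to fp16 is just one way to run half-precision inference; `torch.autocast` is another):

```python
import torch
import torch.nn as nn

# Placeholder model; substitute your own trained model.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))
model = model.half().cuda().eval()  # fp16 weights on a GPU with tensor cores

@torch.inference_mode()  # disables view tracking and version counter bumps
def infer(batch: torch.Tensor) -> torch.Tensor:
    return model(batch)

out = infer(torch.randn(32, 512, device="cuda", dtype=torch.float16))
```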
6 changes: 4 additions & 2 deletions docs/performance_guide.md
@@ -15,6 +15,8 @@ Starting with PyTorch 2.0, `torch.compile` provides out of the box speed up ( ~1

Models which have been fully optimized with `torch.compile` show performance improvements of up to 10x.

When using smaller batch sizes, `mode="reduce-overhead"` with `torch.compile` can give improved performance, as it makes use of CUDA graphs.
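
A minimal sketch of this pattern (the model and shapes are placeholders):

```python
import torch
import torch.nn as nn

# Placeholder model; any nn.Module works the same way.
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10)).cuda().eval()

# "reduce-overhead" uses CUDA graphs to cut per-iteration launch overhead,
# which matters most at small batch sizes.
compiled_model = torch.compile(model, mode="reduce-overhead")

with torch.inference_mode():
    out = compiled_model(torch.randn(4, 128, device="cuda"))
```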

You can find all the examples of `torch.compile` with TorchServe [here](https://github.com/pytorch/serve/tree/master/examples/pt2)

Details regarding `torch.compile` GenAI examples can be found in this [link](https://github.com/pytorch/serve/tree/master/examples/pt2#torchcompile-genai-examples)
@@ -30,13 +32,13 @@ At a high level what TorchServe allows you to do is

To use ONNX with GPU on TorchServe Docker, we need to build an image with [NVIDIA CUDA runtime](https://github.com/NVIDIA/nvidia-docker/wiki/CUDA) as the base image as shown [here](https://github.com/pytorch/serve/blob/master/docker/README.md#create-torchserve-docker-image)

<h4>TensorRT<h4>
<h4>TensorRT</h4>

TorchServe also supports models optimized via TensorRT. To leverage the TensorRT runtime, you can convert your model by [following these instructions](https://github.com/pytorch/TensorRT); once you're done, you'll have serialized weights which you can load with [`torch.jit.load()`](https://pytorch.org/TensorRT/getting_started/getting_started_with_python_api.html#getting-started-with-python-api).

After conversion, there is no difference in how PyTorch treats a TorchScript model vs a TensorRT model.
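
For illustration, a converted model might be loaded like the sketch below; the file name and input shape are hypothetical stand-ins for whatever the conversion step produced:

```python
import torch
import torch_tensorrt  # noqa: F401  # registers the TensorRT runtime ops used by the serialized module

# "model_trt.ts" is a placeholder for the TorchScript file produced by Torch-TensorRT.
trt_model = torch.jit.load("model_trt.ts").cuda().eval()

with torch.inference_mode():
    out = trt_model(torch.randn(1, 3, 224, 224, device="cuda"))
```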

<h4>Better Transformer<h4>
<h4>Better Transformer</h4>

Better Transformer from PyTorch implements a backwards-compatible fast path of `torch.nn.TransformerEncoder` for Transformer encoder inference and does not require model authors to modify their models. BetterTransformer improvements can exceed 2x in speedup and throughput for many common execution scenarios.
You can find more information on Better Transformer [here](https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/) and [here](https://github.com/pytorch/serve/tree/master/examples/Huggingface_Transformers#speed-up-inference-with-better-transformer).
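
As a rough, non-TorchServe-specific illustration, the fast path engages automatically for a stock `nn.TransformerEncoder` run in eval mode without gradient tracking, provided its eligibility conditions are met:

```python
import torch
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model=256, nhead=8, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=6).eval()

with torch.inference_mode():
    # (batch, sequence, embedding); the fast path dispatches automatically when eligible.
    out = encoder(torch.randn(32, 64, 256))
```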
@@ -1,5 +1,8 @@
# Text Classification using a Scriptable Tokenizer

## Deprecation Warning!
This example requires TorchText, which is deprecated. Please use version <= 0.11.0 of TorchServe for this example.

TorchScript is a way to serialize and optimize your PyTorch models.
A scriptable tokenizer is a special tokenizer which is compatible with [TorchScript's compiler](https://pytorch.org/docs/stable/jit.html) so that it can be jointly serialized with a PyTorch model.
When deploying an NLP model, it is important to use the same tokenizer during training and inference to achieve the same model accuracy in both phases of the model life cycle.
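
As a generic sketch of that serialization step (a toy module, not the tokenizer from this example):

```python
import torch
import torch.nn as nn

class TinyClassifier(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(16, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)

# A scriptable tokenizer and the model can be scripted and saved the same way,
# so both travel together in one serialized artifact.
scripted = torch.jit.script(TinyClassifier())
scripted.save("model.pt")
```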
9 changes: 4 additions & 5 deletions requirements/torch_neuronx_linux.txt
@@ -1,7 +1,6 @@
--extra-index-url https://download.pytorch.org/whl/cpu
-r torch_common.txt
torch==1.13.1+cpu
torchvision==0.14.1+cpu
torchtext==0.14.1
torchaudio==0.13.1+cpu
torchdata==0.5.1
torch==2.1.2+cpu
torchvision==0.16.2+cpu
torchtext==0.16.2
torchaudio==2.1.2+cpu
1 change: 1 addition & 0 deletions test/pytest/conftest.py
@@ -14,6 +14,7 @@
collect_ignore.append("test_example_torchrec_dlrm.py")
collect_ignore.append("test_example_near_real_time_video.py")
collect_ignore.append("test_dali_preprocess.py")
collect_ignore.append("test_example_scriptable_tokenzier.py")


@pytest.fixture(scope="module")
129 changes: 129 additions & 0 deletions ts/handler_utils/text_utils.py
@@ -0,0 +1,129 @@
"""
Functions which have been copied from TorchText to remove TorchServe's
dependency on TorchText
from torchtext.data.utils import ngrams_iterator, get_tokenizer
"""

import re


def ngrams_iterator(token_list, ngrams):
    """Return an iterator that yields the given tokens and their ngrams.
    Args:
        token_list: A list of tokens
        ngrams: the number of ngrams.
    Examples:
        >>> token_list = ['here', 'we', 'are']
        >>> list(ngrams_iterator(token_list, 2))
        ['here', 'we', 'are', 'here we', 'we are']
    """

    def _get_ngrams(n):
        return zip(*[token_list[i:] for i in range(n)])

    for x in token_list:
        yield x
    for n in range(2, ngrams + 1):
        for x in _get_ngrams(n):
            yield " ".join(x)


_patterns = [
    r"\'",
    r"\"",
    r"\.",
    r"<br \/>",
    r",",
    r"\(",
    r"\)",
    r"\!",
    r"\?",
    r"\;",
    r"\:",
    r"\s+",
]

_replacements = [
    " ' ",
    "",
    " . ",
    " ",
    " , ",
    " ( ",
    " ) ",
    " ! ",
    " ? ",
    " ",
    " ",
    " ",
]

_patterns_dict = list((re.compile(p), r) for p, r in zip(_patterns, _replacements))


def _basic_english_normalize(line):
    r"""
    Basic normalization for a line of text.
    Normalization includes
    - lowercasing
    - some basic text normalization for English words as follows:
        add spaces before and after '\''
        remove '\"'
        add spaces before and after '.'
        replace '<br \/>' with single space
        add spaces before and after ','
        add spaces before and after '('
        add spaces before and after ')'
        add spaces before and after '!'
        add spaces before and after '?'
        replace ';' with single space
        replace ':' with single space
        replace multiple spaces with single space
    Returns a list of tokens after splitting on whitespace.
    """

    line = line.lower()
    for pattern_re, replaced_str in _patterns_dict:
        line = pattern_re.sub(replaced_str, line)
    return line.split()


def _split_tokenizer(x):  # noqa: F821
    return x.split()


def get_tokenizer(tokenizer, language="en"):
    r"""
    Generate a tokenizer function for a string sentence.
    Args:
        tokenizer: the name of the tokenizer function. If None, it returns the
            split() function, which splits the string sentence by space.
            If "basic_english", it returns _basic_english_normalize(),
            which normalizes the string first and then splits it by space.
            (Unlike TorchText's original helper, this copy does not support
            callable tokenizers or third-party libraries such as spacy, moses,
            toktok, revtok, or subword.)
        language: Default "en"; only English is supported.
    Examples:
        >>> tokenizer = get_tokenizer("basic_english")
        >>> tokens = tokenizer("You can now install TorchText using pip!")
        >>> tokens
        ['you', 'can', 'now', 'install', 'torchtext', 'using', 'pip', '!']
    """

    # default tokenizer is string.split(), added as a module function for serialization
    if tokenizer is None:
        return _split_tokenizer

    if tokenizer == "basic_english":
        if language != "en":
            raise ValueError("Basic normalization is only available for English (en)")
        return _basic_english_normalize
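
For reference, a short usage sketch of the two helpers added above (the input sentence is arbitrary):

```python
from ts.handler_utils.text_utils import get_tokenizer, ngrams_iterator

tokenizer = get_tokenizer("basic_english")
tokens = tokenizer("TorchServe no longer depends on TorchText!")
# -> ['torchserve', 'no', 'longer', 'depends', 'on', 'torchtext', '!']

# Yields each token followed by every space-joined bigram.
bigrams = list(ngrams_iterator(tokens, 2))
```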
3 changes: 2 additions & 1 deletion ts/torch_handler/text_classifier.py
@@ -9,7 +9,8 @@
import torch
import torch.nn.functional as F
from captum.attr import TokenReferenceBase
from torchtext.data.utils import ngrams_iterator

from ts.handler_utils.text_utils import ngrams_iterator

from ..utils.util import map_class_to_label
from .text_handler import TextHandler
3 changes: 2 additions & 1 deletion ts/torch_handler/text_handler.py
@@ -12,7 +12,8 @@
import torch
import torch.nn.functional as F
from captum.attr import LayerIntegratedGradients
from torchtext.data.utils import get_tokenizer

from ts.handler_utils.text_utils import get_tokenizer

from ..utils.util import CLEANUP_REGEX
from .base_handler import BaseHandler
1 change: 1 addition & 0 deletions ts_scripts/spellcheck_conf/wordlist.txt
@@ -1247,3 +1247,4 @@ quant
quantizing
smoothquant
woq
TorchText
