Skip to content

Commit

Permalink
Launcher core pinning (#1401)
Browse files Browse the repository at this point in the history
* added core pinning feature

* added comment on instance_idx improving perf

* (1) removed ninstances, instance_idx that's internal to not confuse readers (2) updated socket_id to node_id (3) added launcher core pinning section

* commented restartWorkers logic

* minor grammar change

* added scaling workers

* fixed java formatting

* updated REAMDE

* updated REAME

* added int8 links

* fix to Intel® Extension for PyTorch*

* added link to int8

* added steps to creating bert_ipex_int8.mar

* added core pinning ut

* include (1) core pinning (2) core pinning with worker scaling gifs

* add worker scaling footnote

* simplify isLauncherRestartWorkers

Co-authored-by: Aaqib <maaquib@gmail.com>

* lint with pre-commit

* dummy commit for ci

* undo dummy commit for ci

* fix minor shift

* check launcher available

* add pytest checkif windows

* added ipex to requirements/deverloper.txt

* check if ipex installed

* fixed mistake

* updated to ipex launcher available

* change naming

* change dummy cmd to ls

Co-authored-by: min-jean-cho <minjeanc@mlp-prod-skx-7825.ra.intel.com>
Co-authored-by: Mark Saroufim <marksaroufim@fb.com>
Co-authored-by: Aaqib <maaquib@gmail.com>
  • Loading branch information
4 people committed Jul 13, 2022
1 parent 42c3225 commit 16d5e2b
Show file tree
Hide file tree
Showing 6 changed files with 370 additions and 15 deletions.
143 changes: 130 additions & 13 deletions examples/intel_extension_for_pytorch/README.md
Original file line number Diff line number Diff line change
@@ -1,30 +1,31 @@
# TorchServe with Intel® Extension for PyTorch*

TorchServe can be used with Intel® Extension for PyTorch* (IPEX) to give performance boost on Intel hardware<sup>1</sup>.
TorchServe can be used with Intel® Extension for PyTorch* (IPEX) to give performance boost on Intel hardware.<sup>1</sup>
Here we show how to use TorchServe with IPEX.

<sup>1. While IPEX benefits all platforms, plaforms with AVX512 benefit the most. </sup>
<sup>1. While IPEX benefits all platforms, platforms with AVX512 benefit the most. </sup>

## Contents of this Document
* [Install Intel Extension for PyTorch](#install-intel-extension-for-pytorch)
* [Serving model with Intel Extension for PyTorch](#serving-model-with-intel-extension-for-pytorch)
* [Install Intel® Extension for PyTorch*](#install-intel-extension-for-pytorch)
* [Serving model with Intel® Extension for PyTorch*](#serving-model-with-intel-extension-for-pytorch)
* [TorchServe with Launcher](#torchserve-with-launcher)
* [Creating and Exporting INT8 model for IPEX](#creating-and-exporting-int8-model-for-ipex)
* [Benchmarking with Launcher](#benchmarking-with-launcher)
* [Performance Boost with IPEX and Launcher](#performance-boost-with-ipex-and-launcher)


## Install Intel Extension for PyTorch
## Install Intel® Extension for PyTorch*
Refer to the documentation [here](https://github.com/intel/intel-extension-for-pytorch#installation).

## Serving model with Intel Extension for PyTorch
## Serving model with Intel® Extension for PyTorch*
After installation, all it needs to be done to use TorchServe with IPEX is to enable it in `config.properties`.
```
ipex_enable=true
```
Once IPEX is enabled, deploying PyTorch model follows the same procedure shown [here](https://pytorch.org/serve/use_cases.html). TorchServe with IPEX can deploy any model and do inference.

## TorchServe with Launcher
Launcher is a script to automate the process of tunining configuration setting on intel hardware to boost performance. Tuning configurations such as OMP_NUM_THREADS, thread affininty, memory allocator can have a dramatic effect on performance. Please refer to [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/tuning_guide.md) and [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md) for details on performance tuning with launcher.
Launcher is a script to automate the process of tunining configuration setting on intel hardware to boost performance. Tuning configurations such as OMP_NUM_THREADS, thread affininty, memory allocator can have a dramatic effect on performance. Please refer to [Performance Tuning Guide](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/tuning_guide.md) and [Launch Script Usage Guide](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md) for details on performance tuning with launcher.

All it needs to be done to use TorchServe with launcher is to set its configuration in `config.properties`.

Expand Down Expand Up @@ -55,19 +56,42 @@ cpu_launcher_enable=true
cpu_launcher_args=--use_logical_core --disable_numactl
```

Some useful `cpu_launcher_args` to note are:
Below is some useful `cpu_launcher_args` to note. Italic values are default if applicable.
1. Memory Allocator: [ PTMalloc `--use_default_allocator` | *TCMalloc `--enable_tcmalloc`* | JeMalloc `--enable_jemalloc`]
* PyTorch by defualt uses PTMalloc. TCMalloc/JeMalloc generally gives better performance.
2. OpenMP library: [GNU OpenMP `--disable_iomp` | *Intel OpenMP*]
* PyTorch by default uses GNU OpenMP. Launcher by default uses Intel OpenMP. Intel OpenMP library generally gives better performance.
3. Socket id: [`--socket_id`]
* Launcher by default uses all physical cores. Limit memory access to local memories on the Nth socket to avoid Non-Uniform Memory Access (NUMA).
3. Node id: [`--node_id`]
* Launcher by default uses all NUMA nodes. Limit memory access to local memories on the Nth Numa node to avoid Non-Uniform Memory Access (NUMA).

Please refer to [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md) for a full list of tunable configuration of launcher.
Please refer to [Launch Script Usage Guide](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md) for a full list of tunable configuration of launcher. And please refer to [Performance Tuning Guide](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/tuning_guide.md) for more details.

### Launcher Core Pinning to Boost Performance of TorchServe Multi Worker Inference
When running [multi-worker inference](https://pytorch.org/serve/management_api.html#scale-workers) with Torchserve, launcher pin cores to workers to boost performance. Internally, launcher equally divides the number of cores by the number of workers such that each worker is pinned to assigned cores. Doing so avoids core overlap between workers which can signficantly boost performance for TorchServe multi-worker inference. For example, assume running 4 workers on a machine with Intel(R) Xeon(R) Platinum 8180 CPU, 2 sockets, 28 cores per socket, 2 threads per core. Launcher will bind worker 0 to cores 0-13, worker 1 to cores 14-27, worker 2 to cores 28-41, and worker 3 to cores 42-55.

CPU usage is shown below. 4 main worker threads were launched, each launching 14 threads affinitized to the assigned physical cores.
![26](https://user-images.githubusercontent.com/93151422/170373651-fd8a0363-febf-4528-bbae-e1ddef119358.gif)


#### Scaling workers
Additionally when dynamically [scaling the number of workers](https://pytorch.org/serve/management_api.html#scale-workers), cores that were pinned to killed workers by the launcher could be left unutilized. To address this problem, launcher internally restarts the workers to re-distribute cores that were pinned to killed workers to the remaining, alive workers. This is taken care internally, so users do not have to worry about this.

Continuing with the above example with 4 workers, assume killing workers 2 and 3. If cores were not re-distributed after the scale down, cores 28-55 would be left unutilized. Instead, launcher re-distributes cores 28-55 to workers 0 and 1 such that now worker 0 binds to cores 0-27 and worker 1 binds to cores 28-55.<sup>2</sup>

CPU usage is shown below. 4 main worker threads were initially launched. Then after scaling down the number of workers from 4 to 2, 2 main worker threads were launched, each launching 28 threads affinitized to the assigned physical cores.
![worker_scaling](https://user-images.githubusercontent.com/93151422/170374697-7497c2d5-4c17-421b-9993-1434d1f722f6.gif)

<sup>2. Serving is interrupted for few seconds while re-distributing cores to scaled workers.</sup>

Again, all it needs to be done to use TorchServe with launcher core pinning for multiple workers as well as scaling workers is to set its configuration in `config.properties`.

Add the following lines in `config.properties` to use launcher with its default configuration.
```
cpu_launcher_enable=true
```

## Creating and Exporting INT8 model for IPEX
Intel Extension for PyTorch supports both eager and torchscript mode. In this section, we show how to deploy INT8 model for IPEX.
Intel® Extension for PyTorch* supports both eager and torchscript mode. In this section, we show how to deploy INT8 model for IPEX. Please refer to [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/features/int8.md) for more details on Intel® Extension for PyTorch* optimizations for quantization.

### 1. Creating a serialized file
First create `.pt` serialized file using IPEX INT8 inference. Here we show two examples with BERT and ResNet50.
Expand Down Expand Up @@ -167,10 +191,18 @@ torch.jit.save(model, 'rn50_int8_jit.pt')
```

### 2. Creating a Model Archive
Once the serialized file ( `.pt`) is created, it can be used with `torch-model-archiver` as ususal. Use the following command to package the model.
Once the serialized file ( `.pt`) is created, it can be used with `torch-model-archiver` as ususal.

Use the following command to package `rn50_int8_jit.pt` into `rn50_ipex_int8.mar`.
```
torch-model-archiver --model-name rn50_ipex_int8 --version 1.0 --serialized-file rn50_int8_jit.pt --handler image_classifier
```
Similarly, use the following command in the [Huggingface_Transformers directory](https://github.com/pytorch/serve/tree/master/examples/Huggingface_Transformers) to package `bert_int8_jit.pt` into `bert_ipex_int8.mar`.

```
torch-model-archiver --model-name bert_ipex_int8 --version 1.0 --serialized-file bert_int8_jit.pt --handler ./Transformer_handler_generalized.py --extra-files "./setup_config.json,./Seq_classification_artifacts/index_to_name.json"
```

### 3. Start TorchServe to serve the model
Make sure to set `ipex_enable=true` in `config.properties`. Use the following command to start TorchServe with IPEX.
```
Expand Down Expand Up @@ -223,3 +255,88 @@ $ cat logs/model_log.log
2021-12-02 06:15:03,982 - __main__ - INFO - LD_PRELOAD=<VIRTUAL_ENV>/lib/libiomp5.so
```

### Benchmarking with Launcher Core Pinning
As described previously in [TorchServe with Launcher](#torchserve-with-launcher), launcher core pinning boosts performance of multi-worker inference. We'll demonstrate launcher core pinning with TorchServe benchmark, but keep in mind that launcher core pinning is a generic feature applicable to any TorchServe multi-worker inference use casese.

For example, assume running 4 workers
```
python benchmark-ab.py --workers 4
```
on a machine with Intel(R) Xeon(R) Platinum 8180 CPU, 2 sockets, 28 cores per socket, 2 threads per core. Launcher will bind worker 0 to cores 0-13, worker 1 to cores 14-27, worker 2 to cores 28-41, and worker 3 to cores 42-55.

All it needs to be done to use TorchServe with launcher's core pinning is to enable launcher in `config.properties`.

Add the following lines to `config.properties` in the benchmark directory to use launcher's core pinning:
```
cpu_launcher_enable=true
```

CPU usage is shown as below:
![launcher_core_pinning](https://user-images.githubusercontent.com/93151422/159063975-e7e8d4b0-e083-4733-bdb6-4d92bdc10556.gif)

4 main worker threads were launched, then each launched a num_physical_cores/num_workers number (14) of threads affinitized to the assigned physical cores.

<pre><code>
$ cat logs/model_log.log
2022-03-24 10:41:32,223 - __main__ - INFO - Use TCMalloc memory allocator
2022-03-24 10:41:32,223 - __main__ - INFO - OMP_NUM_THREADS=14
2022-03-24 10:41:32,223 - __main__ - INFO - Using Intel OpenMP
2022-03-24 10:41:32,223 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0
2022-03-24 10:41:32,223 - __main__ - INFO - KMP_BLOCKTIME=1
2022-03-24 10:41:32,223 - __main__ - INFO - LD_PRELOAD=<VIRTUAL_ENV>/lib/libiomp5.so:<VIRTUAL_ENV>/lib/libtcmalloc.so
2022-03-24 10:41:32,223 - __main__ - INFO - <b>numactl -C 0-13 -m 0</b> <VIRTUAL_ENV>/bin/python -u <VIRTUAL_ENV>/lib/python/site-packages/ts/model_service_worker.py --sock-type unix --sock-name /tmp/.ts.sock.9000

2022-03-24 10:49:03,760 - __main__ - INFO - Use TCMalloc memory allocator
2022-03-24 10:49:03,761 - __main__ - INFO - OMP_NUM_THREADS=14
2022-03-24 10:49:03,762 - __main__ - INFO - Using Intel OpenMP
2022-03-24 10:49:03,762 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0
2022-03-24 10:49:03,762 - __main__ - INFO - KMP_BLOCKTIME=1
2022-03-24 10:49:03,762 - __main__ - INFO - LD_PRELOAD=<VIRTUAL_ENV>/lib/libiomp5.so:<VIRTUAL_ENV>/lib/libtcmalloc.so
2022-03-24 10:49:03,763 - __main__ - INFO - <b>numactl -C 14-27 -m 0</b> <VIRTUAL_ENV>/bin/python -u <VIRTUAL_ENV>/lib/python/site-packages/ts/model_service_worker.py --sock-type unix --sock-name /tmp/.ts.sock.9001

2022-03-24 10:49:26,274 - __main__ - INFO - Use TCMalloc memory allocator
2022-03-24 10:49:26,274 - __main__ - INFO - OMP_NUM_THREADS=14
2022-03-24 10:49:26,274 - __main__ - INFO - Using Intel OpenMP
2022-03-24 10:49:26,274 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0
2022-03-24 10:49:26,274 - __main__ - INFO - KMP_BLOCKTIME=1
2022-03-24 10:49:26,274 - __main__ - INFO - LD_PRELOAD=<VIRTUAL_ENV>/lib/libiomp5.so:<VIRTUAL_ENV>/lib/libtcmalloc.so
2022-03-24 10:49:26,274 - __main__ - INFO - <b>numactl -C 28-41 -m 1</b> <VIRTUAL_ENV>/bin/python -u <VIRTUAL_ENV>/lib/python/site-packages/ts/model_service_worker.py --sock-type unix --sock-name /tmp/.ts.sock.9002

2022-03-24 10:49:42,975 - __main__ - INFO - Use TCMalloc memory allocator
2022-03-24 10:49:42,975 - __main__ - INFO - OMP_NUM_THREADS=14
2022-03-24 10:49:42,975 - __main__ - INFO - Using Intel OpenMP
2022-03-24 10:49:42,975 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0
2022-03-24 10:49:42,975 - __main__ - INFO - KMP_BLOCKTIME=1
2022-03-24 10:49:42,975 - __main__ - INFO - LD_PRELOAD=<VIRTUAL_ENV>/lib/libiomp5.so:<VIRTUAL_ENV>/lib/libtcmalloc.so
2022-03-24 10:49:42,975 - __main__ - INFO - <b>numactl -C 42-55 -m 1</b> <VIRTUAL_ENV>/bin/python -u <VIRTUAL_ENV>/lib/python/site-packages/ts/model_service_worker.py --sock-type unix --sock-name /tmp/.ts.sock.9003
</code></pre>

## Performance Boost with IPEX and Launcher

![pdt_perf](https://user-images.githubusercontent.com/93151422/159067306-dfd604e3-8c66-4365-91ae-c99f68d972d5.png)


Above shows performance improvement of Torchserve with IPEX and launcher on ResNet50 and BERT-base-uncased. Torchserve official [apache-bench benchmark](https://github.com/pytorch/serve/tree/master/benchmarks#benchmarking-with-apache-bench) on Amazon EC2 m6i.24xlarge was used to collect the results<sup>2</sup>. Add the following lines in ```config.properties``` to reproduce the results. Notice that launcher is configured such that a single instance uses all physical cores on a single socket to avoid cross socket communication and core overlap.

```
ipex_enable=true
cpu_launcher_enable=true
cpu_launcher_args=--node_id 0 --enable_jemalloc
```
Use the following command to reproduce the results.
```
python benchmark-ab.py --url {modelUrl} --input {inputPath} --concurrency 1
```

For example, run the following command to reproduce latency performance of ResNet50 with data type of IPEX int8 and batch size of 1. Please refer to [Creating and Exporting INT8 model for IPEX](#creating-and-exporting-int8-model-for-ipex) for steps to creating ```rn50_ipex_int8.mar``` file for ResNet50 with IPEX int8 data type.
```
python benchmark-ab.py --url 'file:///model_store/rn50_ipex_int8.mar' --concurrency 1
```

For example, run the following command to reproduce latency performance of BERT with data type of IPEX int8 and batch size of 1. Please refer to [Creating and Exporting INT8 model for IPEX](#creating-and-exporting-int8-model-for-ipex) for steps to creating ```bert_ipex_int8.mar``` file for BERT with IPEX int8 data type.
```
python benchmark-ab.py --url 'file:///model_store/bert_ipex_int8.mar' --input '../examples/Huggingface_Transformers/Seq_classification_artifacts/sample_text_captum_input.txt' --concurrency 1
```

<sup>3. Amazon EC2 m6i.24xlarge was used for benchmarking purpose only. For multi-core instances, ipex optimizations automatically scale and leverage full instance resources.</sup>
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,26 @@ public int getNumRunningWorkers(ModelVersionName modelVersionName) {
return numWorking;
}

/**
* Checks if cpu_launcher is enabled and currentWorkers > 0 (i.e., not initializing workers).
* Workers are restarted so that when dynamically scaling the number of workers, cores that were
* pinned to killed workers by the launcher are not left unutilizied. If isRestart, workers are
* restarted to re-distribute cores that were pinned to killed workers to the remaining, alive
* workers.
*/
public boolean isLauncherRestartWorkers(int currentWorkers) {
return configManager.isCPULauncherEnabled() && currentWorkers > 0;
}

public CompletableFuture<Integer> modelChanged(
Model model, boolean isStartup, boolean isCleanUp) {
synchronized (model.getModelVersionName()) {
boolean isSnapshotSaved = false;
CompletableFuture<Integer> future = new CompletableFuture<>();
int minWorker = model.getMinWorkers();
int maxWorker = model.getMaxWorkers();
// Sets restartNumWorkers to the updated minWorker after scale up/down
int restartNumWorkers = minWorker;
List<WorkerThread> threads;
if (minWorker == 0) {
threads = workers.remove(model.getModelVersionName());
Expand All @@ -109,6 +122,18 @@ public CompletableFuture<Integer> modelChanged(
}

int currentWorkers = threads.size();
boolean isRestartWorkers = isLauncherRestartWorkers(currentWorkers);

if (isRestartWorkers) {
logger.warn(
"removing {} current thread(s) prior to restarting {} thread(s)",
currentWorkers,
minWorker);
// By setting maxWorker and minWorker to 0, removes all currentWorkers
maxWorker = 0;
minWorker = 0;
}

if (currentWorkers < minWorker) {
addThreads(threads, model, minWorker - currentWorkers, future);
} else {
Expand Down Expand Up @@ -150,6 +175,13 @@ public CompletableFuture<Integer> modelChanged(
}
future.complete(HttpURLConnection.HTTP_OK);
}

// After removing all currentWorkers, add back (i.e., restart) restartNumWorkers
if (isRestartWorkers) {
logger.warn("restarting {} thread(s)", restartNumWorkers);
addThreads(threads, model, restartNumWorkers, future);
}

if (!isStartup && !isSnapshotSaved && !isCleanUp && !model.isWorkflowModel()) {
SnapshotManager.getInstance().saveSnapshot();
}
Expand Down
Loading

0 comments on commit 16d5e2b

Please sign in to comment.