Launcher core pinning #1401

Merged Jul 13, 2022 (38 commits). Changes shown below are from 12 commits.

Commits
- `52bdf66` added core pinning feature (Jan 12, 2022)
- `9ffaaf5` added comment on instance_idx improving perf (min-jean-cho, Jan 13, 2022)
- `d95e7aa` Merge branch 'pytorch:master' into launcher_core_pinning (min-jean-cho, Mar 18, 2022)
- `3417c90` (1) removed ninstances, instance_idx that's internal to not confuse r… (min-jean-cho, Mar 18, 2022)
- `d8b5cce` Merge branch 'master' into launcher_core_pinning (msaroufim, Mar 23, 2022)
- `bcfcfa8` commented restartWorkers logic (Mar 23, 2022)
- `c78f8ce` minor grammar change (Mar 23, 2022)
- `9eb94b6` added scaling workers (min-jean-cho, Mar 23, 2022)
- `4620e08` fixed java formatting (Mar 24, 2022)
- `ac01674` updated REAMDE (min-jean-cho, Mar 24, 2022)
- `9d0c035` updated REAME (min-jean-cho, Mar 24, 2022)
- `3901593` Merge branch 'pytorch:master' into launcher_core_pinning (min-jean-cho, Apr 4, 2022)
- `3d86c01` added int8 links (min-jean-cho, Apr 11, 2022)
- `7ebdc91` fix to Intel® Extension for PyTorch* (min-jean-cho, Apr 11, 2022)
- `d34bd5c` added link to int8 (min-jean-cho, Apr 11, 2022)
- `290dac3` added steps to creating bert_ipex_int8.mar (min-jean-cho, Apr 11, 2022)
- `638202e` Merge branch 'pytorch:master' into launcher_core_pinning (min-jean-cho, Apr 18, 2022)
- `4893885` Merge branch 'master' into launcher_core_pinning (min-jean-cho, Apr 26, 2022)
- `4594f98` Merge branch 'master' into launcher_core_pinning (min-jean-cho, Apr 28, 2022)
- `5d07cb2` added core pinning ut (Apr 29, 2022)
- `f99d54f` Merge branch 'master' into launcher_core_pinning (msaroufim, May 5, 2022)
- `7c0af9d` Merge branch 'pytorch:master' into launcher_core_pinning (min-jean-cho, May 25, 2022)
- `23385a7` include (1) core pinning (2) core pinning with worker scaling gifs (min-jean-cho, May 25, 2022)
- `c848f16` add worker scaling footnote (min-jean-cho, May 25, 2022)
- `24ae718` Merge branch 'master' into launcher_core_pinning (min-jean-cho, Jun 29, 2022)
- `72c1ca9` simplify isLauncherRestartWorkers (min-jean-cho, Jun 30, 2022)
- `ef850d3` lint with pre-commit (Jun 30, 2022)
- `a123334` dummy commit for ci (Jun 30, 2022)
- `d06cc8c` undo dummy commit for ci (Jun 30, 2022)
- `d7f82a0` fix minor shift (Jul 5, 2022)
- `57d34cd` check launcher available (Jul 5, 2022)
- `574653c` add pytest checkif windows (Jul 6, 2022)
- `cb489ee` added ipex to requirements/deverloper.txt (Jul 6, 2022)
- `1a7fd9a` check if ipex installed (Jul 6, 2022)
- `86a160c` fixed mistake (Jul 6, 2022)
- `50eaaa0` updated to ipex launcher available (Jul 6, 2022)
- `4e484b5` change naming (Jul 6, 2022)
- `2827603` change dummy cmd to ls (Jul 6, 2022)
113 changes: 107 additions & 6 deletions examples/intel_extension_for_pytorch/README.md
@@ -3,14 +3,15 @@
TorchServe can be used with Intel® Extension for PyTorch* (IPEX) to give a performance boost on Intel hardware<sup>1</sup>.
Here we show how to use TorchServe with IPEX.

<sup>1. While IPEX benefits all platforms, platforms with AVX512 benefit the most. </sup>

## Contents of this Document
* [Install Intel Extension for PyTorch](#install-intel-extension-for-pytorch)
* [Serving model with Intel Extension for PyTorch](#serving-model-with-intel-extension-for-pytorch)
* [TorchServe with Launcher](#torchserve-with-launcher)
* [Creating and Exporting INT8 model for IPEX](#creating-and-exporting-int8-model-for-ipex)
* [Benchmarking with Launcher](#benchmarking-with-launcher)
* [Performance Boost with IPEX and Launcher](#performance-boost-with-ipex-and-launcher)


## Install Intel Extension for PyTorch
@@ -24,7 +25,7 @@ ipex_enable=true
Once IPEX is enabled, deploying a PyTorch model follows the same procedure shown [here](https://pytorch.org/serve/use_cases.html). TorchServe with IPEX can deploy any model and do inference.

## TorchServe with Launcher
Launcher is a script that automates the process of tuning configuration settings on Intel hardware to boost performance. Tuning configurations such as OMP_NUM_THREADS, thread affinity, and memory allocator can have a dramatic effect on performance. Please refer to [Performance Tuning Guide](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/tuning_guide.md) and [Launch Script Usage Guide](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md) for details on performance tuning with launcher.

All that needs to be done to use TorchServe with the launcher is to set its configuration in `config.properties`.

@@ -55,16 +56,31 @@ cpu_launcher_enable=true
cpu_launcher_args=--use_logical_core --disable_numactl
```

Below are some useful `cpu_launcher_args` to note. Italicized values are the defaults where applicable.
1. Memory Allocator: [ PTMalloc `--use_default_allocator` | *TCMalloc `--enable_tcmalloc`* | JeMalloc `--enable_jemalloc` ]
   * PyTorch by default uses PTMalloc. TCMalloc/JeMalloc generally give better performance.
2. OpenMP library: [ GNU OpenMP `--disable_iomp` | *Intel OpenMP* ]
   * PyTorch by default uses GNU OpenMP. Launcher by default uses Intel OpenMP. The Intel OpenMP library generally gives better performance.
3. Node id: [`--node_id`]
   * Launcher by default uses all NUMA nodes. Use this to limit memory access to the local memory of the Nth NUMA node and avoid slower remote memory access.

Please refer to [Launch Script Usage Guide](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md) for the full list of tunable configurations of the launcher, and to [Performance Tuning Guide](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/tuning_guide.md) for more details.

### Launcher Core Pinning to Boost Performance of TorchServe Multi-Worker Inference
When running [multi-worker inference](https://pytorch.org/serve/management_api.html#scale-workers) with TorchServe, the launcher pins cores to workers to boost performance. Internally, the launcher divides the number of cores equally by the number of workers, so that each worker is pinned to its assigned cores. Doing so avoids core overlap between workers, which can significantly boost performance for TorchServe multi-worker inference. For example, assume running 4 workers on a machine with an Intel(R) Xeon(R) Platinum 8180 CPU: 2 sockets, 28 cores per socket, 2 threads per core. The launcher will bind worker 0 to cores 0-13, worker 1 to cores 14-27, worker 2 to cores 28-41, and worker 3 to cores 42-55.
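To make the arithmetic concrete, here is a minimal sketch of the equal division described above; `assign_cores` is a hypothetical helper for illustration, not the launcher's actual implementation:

```python
# Illustrative sketch of equal core division among workers
# (hypothetical helper, not TorchServe/IPEX source code).
def assign_cores(num_physical_cores, num_workers):
    cores_per_worker = num_physical_cores // num_workers
    return [
        range(i * cores_per_worker, (i + 1) * cores_per_worker)
        for i in range(num_workers)
    ]

# 2 sockets x 28 physical cores = 56 cores; 4 workers -> 14 cores per worker
for idx, cores in enumerate(assign_cores(56, 4)):
    print(f"worker {idx}: cores {cores.start}-{cores.stop - 1}")
# worker 0: cores 0-13
# worker 1: cores 14-27
# worker 2: cores 28-41
# worker 3: cores 42-55
```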

#### Scaling workers
Additionally, when dynamically [scaling the number of workers](https://pytorch.org/serve/management_api.html#scale-workers), cores that the launcher had pinned to killed workers could be left unutilized. To address this problem, the launcher internally restarts the workers to re-distribute cores that were pinned to killed workers to the remaining, alive workers. This is taken care of internally, so users do not have to worry about it; note that serving is paused for a few seconds while the workers restart and cores are re-distributed.

*Member:* Would be helpful to test out if this works in a docker image as well

*Collaborator Author:* Did we still want to test out in a docker image? Please let me know. Thanks. cc @msaroufim @lxning

*Member:* Yes because need to make sure that the current docker image has access to the same environment variables

*Collaborator:*
> launcher internally restarts the workers to re-distribute cores that were pinned to killed workers to the remaining.

@min-jean-cho I was wondering if this means serving will be interrupted for the time this re-distribution is taking place?
Continuing with the above example with 4 workers, assume killing workers 2 and 3. If cores were not re-distributed after the scale down, cores 28-55 would be left unutilized. Instead, launcher re-distributes cores 28-55 to workers 0 and 1 such that now worker 0 binds to cores 0-27 and worker 1 binds to cores 28-55.
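The scale-down itself goes through the TorchServe [management API](https://pytorch.org/serve/management_api.html#scale-workers). Below is a minimal sketch in Python; the model name `benchmark` and the default management port 8081 are assumptions for this example:

```python
import requests

# Scale the 'benchmark' model down from 4 workers to 2 via the management API.
# With cpu_launcher_enable=true, the launcher then restarts the remaining
# workers and re-pins cores so none are left unutilized.
resp = requests.put(
    "http://localhost:8081/models/benchmark",
    params={"min_worker": 2, "synchronous": "true"},
)
print(resp.status_code, resp.text)
```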


Again, all that needs to be done to use TorchServe with launcher core pinning, for multiple workers as well as for scaling workers, is to set its configuration in `config.properties`.

Add the following lines in `config.properties` to use launcher with its default configuration.
```
cpu_launcher_enable=true
```

## Creating and Exporting INT8 model for IPEX
Intel Extension for PyTorch supports both eager and TorchScript modes. In this section, we show how to create and deploy an INT8 model for IPEX.
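As context for how such an INT8 model can be produced, here is a rough sketch assuming the `prepare`/`convert` static-quantization API of IPEX 1.12+; the model choice, calibration data, and file names are illustrative, not prescribed by this PR:

```python
import torch
import torchvision.models as models
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert

# Sketch only: IPEX >= 1.12 static INT8 quantization API assumed.
model = models.resnet50(pretrained=True).eval()
example_input = torch.rand(1, 3, 224, 224)

qconfig = ipex.quantization.default_static_qconfig
prepared = prepare(model, qconfig, example_inputs=example_input, inplace=False)

# Calibrate with representative inputs (random tensors here for brevity).
with torch.no_grad():
    for _ in range(10):
        prepared(torch.rand(1, 3, 224, 224))

quantized = convert(prepared)
with torch.no_grad():
    traced = torch.jit.trace(quantized, example_input)
    traced = torch.jit.freeze(traced)
traced.save("rn50_ipex_int8.pt")
# The saved TorchScript model can then be packaged into a .mar archive
# with torch-model-archiver, as shown in the linked guide.
```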
@@ -223,3 +239,88 @@

```
$ cat logs/model_log.log
2021-12-02 06:15:03,982 - __main__ - INFO - LD_PRELOAD=<VIRTUAL_ENV>/lib/libiomp5.so
```

### Benchmarking with Launcher Core Pinning
As described previously in [TorchServe with Launcher](#torchserve-with-launcher), launcher core pinning boosts performance of multi-worker inference. We'll demonstrate launcher core pinning with the TorchServe benchmark, but keep in mind that launcher core pinning is a generic feature applicable to any TorchServe multi-worker inference use case.

For example, assume running 4 workers
```
python benchmark-ab.py --workers 4
```
on a machine with Intel(R) Xeon(R) Platinum 8180 CPU, 2 sockets, 28 cores per socket, 2 threads per core. Launcher will bind worker 0 to cores 0-13, worker 1 to cores 14-27, worker 2 to cores 28-41, and worker 3 to cores 42-55.

All that needs to be done to use TorchServe with the launcher's core pinning is to enable the launcher in `config.properties`.

Add the following lines to `config.properties` in the benchmark directory to use launcher's core pinning:
```
cpu_launcher_enable=true
```

CPU usage is shown below:
![launcher_core_pinning](https://user-images.githubusercontent.com/93151422/159063975-e7e8d4b0-e083-4733-bdb6-4d92bdc10556.gif)

4 main worker threads were launched, and each of them then launched num_physical_cores/num_workers (14) threads affinitized to its assigned physical cores.

<pre><code>
$ cat logs/model_log.log
2022-03-24 10:41:32,223 - __main__ - INFO - Use TCMalloc memory allocator
2022-03-24 10:41:32,223 - __main__ - INFO - OMP_NUM_THREADS=14
2022-03-24 10:41:32,223 - __main__ - INFO - Using Intel OpenMP
2022-03-24 10:41:32,223 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0
2022-03-24 10:41:32,223 - __main__ - INFO - KMP_BLOCKTIME=1
2022-03-24 10:41:32,223 - __main__ - INFO - LD_PRELOAD=<VIRTUAL_ENV>/lib/libiomp5.so:<VIRTUAL_ENV>/lib/libtcmalloc.so
2022-03-24 10:41:32,223 - __main__ - INFO - <b>numactl -C 0-13 -m 0</b> <VIRTUAL_ENV>/bin/python -u <VIRTUAL_ENV>/lib/python/site-packages/ts/model_service_worker.py --sock-type unix --sock-name /tmp/.ts.sock.9000

2022-03-24 10:49:03,760 - __main__ - INFO - Use TCMalloc memory allocator
2022-03-24 10:49:03,761 - __main__ - INFO - OMP_NUM_THREADS=14
2022-03-24 10:49:03,762 - __main__ - INFO - Using Intel OpenMP
2022-03-24 10:49:03,762 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0
2022-03-24 10:49:03,762 - __main__ - INFO - KMP_BLOCKTIME=1
2022-03-24 10:49:03,762 - __main__ - INFO - LD_PRELOAD=<VIRTUAL_ENV>/lib/libiomp5.so:<VIRTUAL_ENV>/lib/libtcmalloc.so
2022-03-24 10:49:03,763 - __main__ - INFO - <b>numactl -C 14-27 -m 0</b> <VIRTUAL_ENV>/bin/python -u <VIRTUAL_ENV>/lib/python/site-packages/ts/model_service_worker.py --sock-type unix --sock-name /tmp/.ts.sock.9001

2022-03-24 10:49:26,274 - __main__ - INFO - Use TCMalloc memory allocator
2022-03-24 10:49:26,274 - __main__ - INFO - OMP_NUM_THREADS=14
2022-03-24 10:49:26,274 - __main__ - INFO - Using Intel OpenMP
2022-03-24 10:49:26,274 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0
2022-03-24 10:49:26,274 - __main__ - INFO - KMP_BLOCKTIME=1
2022-03-24 10:49:26,274 - __main__ - INFO - LD_PRELOAD=<VIRTUAL_ENV>/lib/libiomp5.so:<VIRTUAL_ENV>/lib/libtcmalloc.so
2022-03-24 10:49:26,274 - __main__ - INFO - <b>numactl -C 28-41 -m 1</b> <VIRTUAL_ENV>/bin/python -u <VIRTUAL_ENV>/lib/python/site-packages/ts/model_service_worker.py --sock-type unix --sock-name /tmp/.ts.sock.9002

2022-03-24 10:49:42,975 - __main__ - INFO - Use TCMalloc memory allocator
2022-03-24 10:49:42,975 - __main__ - INFO - OMP_NUM_THREADS=14
2022-03-24 10:49:42,975 - __main__ - INFO - Using Intel OpenMP
2022-03-24 10:49:42,975 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0
2022-03-24 10:49:42,975 - __main__ - INFO - KMP_BLOCKTIME=1
2022-03-24 10:49:42,975 - __main__ - INFO - LD_PRELOAD=<VIRTUAL_ENV>/lib/libiomp5.so:<VIRTUAL_ENV>/lib/libtcmalloc.so
2022-03-24 10:49:42,975 - __main__ - INFO - <b>numactl -C 42-55 -m 1</b> <VIRTUAL_ENV>/bin/python -u <VIRTUAL_ENV>/lib/python/site-packages/ts/model_service_worker.py --sock-type unix --sock-name /tmp/.ts.sock.9003
</code></pre>

## Performance Boost with IPEX and Launcher

![pdt_perf](https://user-images.githubusercontent.com/93151422/159067306-dfd604e3-8c66-4365-91ae-c99f68d972d5.png)


The figure above shows the performance improvement of TorchServe with IPEX and launcher on ResNet50 and BERT-base-uncased. The official TorchServe [apache-bench benchmark](https://github.com/pytorch/serve/tree/master/benchmarks#benchmarking-with-apache-bench) on Amazon EC2 m6i.24xlarge was used to collect the results<sup>2</sup>. Add the following lines to ```config.properties``` to reproduce the results. Notice that the launcher is configured such that a single instance uses all physical cores on a single socket to avoid cross-socket communication and core overlap.

```
ipex_enable=true
cpu_launcher_enable=true
cpu_launcher_args=--node_id 0 --enable_jemalloc
```
Use the following command to reproduce the results.
```
python benchmark-ab.py --url {modelUrl} --input {inputPath} --concurrency 1
```

For example, run the following command to reproduce the latency performance of ResNet50 with IPEX int8 data type and batch size 1.
*Collaborator:* it will be good to mention/ add a link to how int8 quantization was applied to these models.

*Collaborator Author (Apr 11, 2022):* Hi @HamidShojanazeri , that's a good question - serving is indeed paused/interrupted for a few seconds while redistributing. But there's no code crash or anything. Let me know if this suffices or if you would like further experiments. Thanks

*Collaborator:* thanks @min-jean-cho, I think it might be good to clarify it in the doc. cc: @lxning

*Collaborator Author:* Hi @HamidShojanazeri , I have added a gif and footnote to demonstrate that serving is interrupted for a few seconds while re-distributing cores. Please have a look at the updated README https://github.com/min-jean-cho/serve/tree/launcher_core_pinning/examples/intel_extension_for_pytorch#scaling-workers . Thanks

```
python benchmark-ab.py --url 'file:///model_store/rn50_ipex_int8.mar' --concurrency 1
```

*Member:* would need to add more detail on how to get this mar file? Or put a link to something like an S3 bucket or gdrive

*Collaborator Author:* I guess I didn't make it clear that the rn50_ipex_int8.mar is created from steps shown here: https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/torchserve.md#2-creating-a-model-archive

For example, run the following command to reproduce the latency performance of BERT with IPEX int8 data type and batch size 1.
```
python benchmark-ab.py --url 'file:///model_store/bert_ipex_int8.mar' --input '../examples/Huggingface_Transformers/Seq_classification_artifacts/sample_text_captum_input.txt' --concurrency 1
```

<sup>2. Amazon EC2 m6i.24xlarge was used for benchmarking purposes only. For multi-core instances, IPEX optimizations automatically scale and leverage full instance resources.</sup>
@@ -85,13 +85,32 @@ public int getNumRunningWorkers(ModelVersionName modelVersionName) {
return numWorking;
}

/**
 * Checks if cpu_launcher is enabled and currentWorkers > 0 (i.e., not initializing workers).
 * Workers are restarted so that when dynamically scaling the number of workers, cores that were
 * pinned to killed workers by the launcher are not left unutilized. If isRestart, workers are
 * restarted to re-distribute cores that were pinned to killed workers to the remaining, alive
 * workers.
 */
public boolean isLauncherRestartWorkers(int currentWorkers) {
    boolean isRestart;
    if (configManager.isCPULauncherEnabled() && currentWorkers > 0) {
        // Restart only when scaling up/down; currentWorkers == 0 means
        // workers are still being initialized.
        isRestart = true;
    } else {
        isRestart = false;
    }
    return isRestart;
}

*Collaborator Author:* checks cpu_launcher is enabled and currentWorkers is greater than 0 (i.e., only restarts workers when scaling workers up/down; if currentWorkers==0, then workers are being initialized)

public CompletableFuture<Integer> modelChanged(
Model model, boolean isStartup, boolean isCleanUp) {
synchronized (model.getModelVersionName()) {
boolean isSnapshotSaved = false;
CompletableFuture<Integer> future = new CompletableFuture<>();
int minWorker = model.getMinWorkers();
int maxWorker = model.getMaxWorkers();
// Sets restartNumWorkers to the updated minWorker after scale up/down
int restartNumWorkers = minWorker;

List<WorkerThread> threads;
if (minWorker == 0) {
threads = workers.remove(model.getModelVersionName());
@@ -109,6 +128,18 @@
}

int currentWorkers = threads.size();
boolean isRestartWorkers = isLauncherRestartWorkers(currentWorkers);

if (isRestartWorkers) {
logger.warn(
"removing {} current thread(s) prior to restarting {} thread(s)",
currentWorkers,
minWorker);
// By setting maxWorker and minWorker to 0, removes all currentWorkers
maxWorker = 0;
minWorker = 0;
*Collaborator Author (Jan 13, 2022):* By setting minWorker=0 when currentWorkers > 0, this will never enter the if (currentWorkers < minWorker) condition in line 134. And by setting maxWorker=0, the for loop in line 137 will proceed to remove all currentWorkers.
}

if (currentWorkers < minWorker) {
addThreads(threads, model, minWorker - currentWorkers, future);
} else {
@@ -150,6 +181,13 @@
}
future.complete(HttpURLConnection.HTTP_OK);
}

// After removing all currentWorkers, add back (i.e., restart) restartNumWorkers
if (isRestartWorkers) {
logger.warn("restarting {} thread(s)", restartNumWorkers);
addThreads(threads, model, restartNumWorkers, future);
*Member:* Could you explain a bit more about the overall logic of this PR - it seems to do (my very rough understanding)

1. Check if CPU launcher available and if so 0 out number of workers
2. Then add threads back to model set to the old min worker
3. Finally when a new process is launched make sure it includes the important CLI arguments like n-instance and instance-idx

*Collaborator Author (Jan 13, 2022):* Workers are restarted only when cpu_launcher is enabled (this does not affect TorchServe when cpu_launcher is not enabled). Restarting workers is useful/needed when the user scales the number of workers up/down during execution ( https://pytorch.org/serve/management_api.html#scale-workers ). By restarting workers, launcher 1) re-allocates the cores that were pinned to killed workers in case of scale down; 2) avoids core overlap in case of scale up. It kills all existing workers and restarts the scaled up/down number of workers. By doing so, launcher is configured properly when the user scales the number of workers during execution. I will add comments to the file to explain the logic.

*Collaborator Author:* Finally, after removing all currentNumWorkers of workers, add back (i.e., restart) restartNumWorkers. Feel free to have a try with apache benchmark. For example, set initial workers to 1 and scale up to 5 via `curl -v -X PUT "http://localhost:8081/models/benchmark?min_worker=5"`, or set initial workers to 5 and scale down to 1 via `curl -v -X PUT "http://localhost:8081/models/benchmark?min_worker=1"`.
}

if (!isStartup && !isSnapshotSaved && !isCleanUp && !model.isWorkflowModel()) {
SnapshotManager.getInstance().saveSnapshot();
}
frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java
@@ -21,6 +21,7 @@ public class WorkerLifeCycle {
private static final Logger logger = LoggerFactory.getLogger(WorkerLifeCycle.class);

private ConfigManager configManager;
private ModelManager modelManager = ModelManager.getInstance();
private Model model;
private int pid = -1;
private Process process;
@@ -30,10 +31,14 @@
private ReaderThread errReader;
private ReaderThread outReader;
private String launcherArgs;
private int numWorker;
private int currNumRunningWorkers;

public WorkerLifeCycle(ConfigManager configManager, Model model) {
this.configManager = configManager;
this.model = model;
this.numWorker = model.getMinWorkers();
this.currNumRunningWorkers = modelManager.getNumRunningWorkers(model.getModelVersionName());
}

public Process getProcess() {
@@ -44,8 +49,6 @@ public ArrayList<String> launcherArgsToList() {
ArrayList<String> arrlist = new ArrayList<String>();
arrlist.add("-m");
arrlist.add("intel_extension_for_pytorch.cpu.launch");
arrlist.add("--ninstance");
arrlist.add("1");
if (launcherArgs != null && launcherArgs.length() > 1) {
String[] argarray = launcherArgs.split(" ");
for (int i = 0; i < argarray.length; i++) {
@@ -99,6 +102,16 @@ public void startWorker(int port) throws WorkerInitializationException, Interrup
if (launcherAvailable) {
ArrayList<String> args = launcherArgsToList();
argl.addAll(args);

// multi-worker core pinning
if (this.numWorker > 1) {
argl.add("--ninstances");
argl.add(String.valueOf(this.numWorker));
argl.add("--instance_idx");
*Collaborator:* A WorkerLifeCycle object is associated with one specific backend worker. I can see here each worker is assigned the same idx.

I'm not sure if each backend worker needs a specific instance idx in launcher. Could you confirm that each backend worker can use the same instance idx?

*Collaborator Author:* Hi @lxning , each worker is assigned currNumRunningWorkers ( https://github.com/min-jean-cho/serve/blob/launcher_core_pinning/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java#L41 ). For multiple workers, each backend worker needs a specific instance idx in the launcher - otherwise, all workers will be mapped to the same cores.

Please see below an example of ninstances and instance_idx for 4 workers.

2022-04-28T11:58:40,347 [DEBUG] W-9000-benchmark_1.0 org.pytorch.serve.wlm.WorkerLifeCycle - Worker cmdline: [/nfs/site/home/minjeanc/anaconda3/envs/py36/bin/python3, -m, intel_extension_for_pytorch.cpu.launch, --node_id, 0, --ninstances, 4, --instance_idx, 0, /ec/pdx/disks/mlp_lab_home_pool_02/minjeanc/.local/lib/python3.6/site-packages/ts/model_service_worker.py, --sock-type, unix, --sock-name, /tmp/.ts.sock.9000]
2022-04-28T11:58:40,331 [DEBUG] W-9001-benchmark_1.0 org.pytorch.serve.wlm.WorkerLifeCycle - Worker cmdline: [/nfs/site/home/minjeanc/anaconda3/envs/py36/bin/python3, -m, intel_extension_for_pytorch.cpu.launch, --node_id, 0, --ninstances, 4, --instance_idx, 1, /ec/pdx/disks/mlp_lab_home_pool_02/minjeanc/.local/lib/python3.6/site-packages/ts/model_service_worker.py, --sock-type, unix, --sock-name, /tmp/.ts.sock.9001]
2022-04-28T11:58:40,324 [DEBUG] W-9002-benchmark_1.0 org.pytorch.serve.wlm.WorkerLifeCycle - Worker cmdline: [/nfs/site/home/minjeanc/anaconda3/envs/py36/bin/python3, -m, intel_extension_for_pytorch.cpu.launch, --node_id, 0, --ninstances, 4, --instance_idx, 2, /ec/pdx/disks/mlp_lab_home_pool_02/minjeanc/.local/lib/python3.6/site-packages/ts/model_service_worker.py, --sock-type, unix, --sock-name, /tmp/.ts.sock.9002]
2022-04-28T11:58:40,344 [DEBUG] W-9003-benchmark_1.0 org.pytorch.serve.wlm.WorkerLifeCycle - Worker cmdline: [/nfs/site/home/minjeanc/anaconda3/envs/py36/bin/python3, -m, intel_extension_for_pytorch.cpu.launch, --node_id, 0, --ninstances, 4, --instance_idx, 3, /ec/pdx/disks/mlp_lab_home_pool_02/minjeanc/.local/lib/python3.6/site-packages/ts/model_service_worker.py, --sock-type, unix, --sock-name, /tmp/.ts.sock.9003]
// instance_idx is 0-indexed
argl.add(String.valueOf(this.currNumRunningWorkers));
}

} else {
logger.warn(
"CPU launcher is enabled but launcher is not available. Proceeding without launcher.");