diff --git a/examples/intel_extension_for_pytorch/README.md b/examples/intel_extension_for_pytorch/README.md index 99c1efc2ee..eb228c9d0e 100644 --- a/examples/intel_extension_for_pytorch/README.md +++ b/examples/intel_extension_for_pytorch/README.md @@ -1,22 +1,23 @@ # TorchServe with Intel® Extension for PyTorch* -TorchServe can be used with Intel® Extension for PyTorch* (IPEX) to give performance boost on Intel hardware1. +TorchServe can be used with Intel® Extension for PyTorch* (IPEX) to give a performance boost on Intel hardware.1 Here we show how to use TorchServe with IPEX. -1. While IPEX benefits all platforms, plaforms with AVX512 benefit the most. +1. While IPEX benefits all platforms, platforms with AVX512 benefit the most. ## Contents of this Document -* [Install Intel Extension for PyTorch](#install-intel-extension-for-pytorch) -* [Serving model with Intel Extension for PyTorch](#serving-model-with-intel-extension-for-pytorch) +* [Install Intel® Extension for PyTorch*](#install-intel-extension-for-pytorch) +* [Serving model with Intel® Extension for PyTorch*](#serving-model-with-intel-extension-for-pytorch) * [TorchServe with Launcher](#torchserve-with-launcher) * [Creating and Exporting INT8 model for IPEX](#creating-and-exporting-int8-model-for-ipex) * [Benchmarking with Launcher](#benchmarking-with-launcher) +* [Performance Boost with IPEX and Launcher](#performance-boost-with-ipex-and-launcher) -## Install Intel Extension for PyTorch +## Install Intel® Extension for PyTorch* Refer to the documentation [here](https://github.com/intel/intel-extension-for-pytorch#installation). -## Serving model with Intel Extension for PyTorch +## Serving model with Intel® Extension for PyTorch* After installation, all it needs to be done to use TorchServe with IPEX is to enable it in `config.properties`. ``` ipex_enable=true @@ -24,7 +25,7 @@ ipex_enable=true Once IPEX is enabled, deploying PyTorch model follows the same procedure shown [here](https://pytorch.org/serve/use_cases.html). TorchServe with IPEX can deploy any model and do inference. ## TorchServe with Launcher -Launcher is a script to automate the process of tunining configuration setting on intel hardware to boost performance. Tuning configurations such as OMP_NUM_THREADS, thread affininty, memory allocator can have a dramatic effect on performance. Please refer to [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/tuning_guide.md) and [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md) for details on performance tuning with launcher. +Launcher is a script that automates tuning configuration settings on Intel hardware to boost performance. Tuning configurations such as OMP_NUM_THREADS, thread affinity, and the memory allocator can have a dramatic effect on performance. Please refer to the [Performance Tuning Guide](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/tuning_guide.md) and the [Launch Script Usage Guide](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md) for details on performance tuning with launcher. All it needs to be done to use TorchServe with launcher is to set its configuration in `config.properties`.
@@ -55,19 +56,42 @@ cpu_launcher_enable=true cpu_launcher_args=--use_logical_core --disable_numactl ``` -Some useful `cpu_launcher_args` to note are: +Below are some useful `cpu_launcher_args` to note. Italicized values are the defaults, where applicable. 1. Memory Allocator: [ PTMalloc `--use_default_allocator` | *TCMalloc `--enable_tcmalloc`* | JeMalloc `--enable_jemalloc`] * PyTorch by defualt uses PTMalloc. TCMalloc/JeMalloc generally gives better performance. 2. OpenMP library: [GNU OpenMP `--disable_iomp` | *Intel OpenMP*] * PyTorch by default uses GNU OpenMP. Launcher by default uses Intel OpenMP. Intel OpenMP library generally gives better performance. -3. Socket id: [`--socket_id`] - * Launcher by default uses all physical cores. Limit memory access to local memories on the Nth socket to avoid Non-Uniform Memory Access (NUMA). +3. Node id: [`--node_id`] + * Launcher by default uses all NUMA nodes. Limit memory access to the local memory of the Nth NUMA node to avoid Non-Uniform Memory Access (NUMA) overhead. -Please refer to [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md) for a full list of tunable configuration of launcher. +Please refer to the [Launch Script Usage Guide](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md) for the full list of tunable launcher configurations, and to the [Performance Tuning Guide](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/tuning_guide.md) for more details. +### Launcher Core Pinning to Boost Performance of TorchServe Multi-Worker Inference +When running [multi-worker inference](https://pytorch.org/serve/management_api.html#scale-workers) with TorchServe, launcher pins cores to workers to boost performance. Internally, launcher divides the number of cores equally among the workers so that each worker is pinned to its assigned cores. Doing so avoids core overlap between workers, which can significantly boost performance for TorchServe multi-worker inference. For example, assume running 4 workers on a machine with Intel(R) Xeon(R) Platinum 8180 CPU, 2 sockets, 28 cores per socket, 2 threads per core. Launcher will bind worker 0 to cores 0-13, worker 1 to cores 14-27, worker 2 to cores 28-41, and worker 3 to cores 42-55. + +CPU usage is shown below. 4 main worker threads were launched, each launching 14 threads affinitized to the assigned physical cores. +![26](https://user-images.githubusercontent.com/93151422/170373651-fd8a0363-febf-4528-bbae-e1ddef119358.gif) + + +#### Scaling workers +Additionally, when dynamically [scaling the number of workers](https://pytorch.org/serve/management_api.html#scale-workers), cores that were pinned to killed workers by the launcher could be left unutilized. To address this problem, launcher internally restarts the workers to re-distribute cores that were pinned to killed workers to the remaining, alive workers. This is taken care of internally, so users do not have to worry about it. + +Continuing with the above example of 4 workers, assume workers 2 and 3 are killed. If cores were not re-distributed after the scale down, cores 28-55 would be left unutilized. Instead, launcher re-distributes cores 28-55 to workers 0 and 1 such that now worker 0 binds to cores 0-27 and worker 1 binds to cores 28-55.2 + +CPU usage is shown below. 4 main worker threads were initially launched.
Then after scaling down the number of workers from 4 to 2, 2 main worker threads were launched, each launching 28 threads affinitized to the assigned physical cores. +![worker_scaling](https://user-images.githubusercontent.com/93151422/170374697-7497c2d5-4c17-421b-9993-1434d1f722f6.gif) + +2. Serving is interrupted for a few seconds while re-distributing cores to scaled workers. + +Again, all that needs to be done to use TorchServe with launcher core pinning, for multiple workers as well as for scaling workers, is to set its configuration in `config.properties`. + +Add the following line to `config.properties` to use launcher with its default configuration. +``` +cpu_launcher_enable=true +``` ## Creating and Exporting INT8 model for IPEX -Intel Extension for PyTorch supports both eager and torchscript mode. In this section, we show how to deploy INT8 model for IPEX. +Intel® Extension for PyTorch* supports both eager and TorchScript modes. In this section, we show how to deploy an INT8 model for IPEX. Please refer to [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/features/int8.md) for more details on Intel® Extension for PyTorch* optimizations for quantization. ### 1. Creating a serialized file First create `.pt` serialized file using IPEX INT8 inference. Here we show two examples with BERT and ResNet50. @@ -167,10 +191,18 @@ torch.jit.save(model, 'rn50_int8_jit.pt') ``` ### 2. Creating a Model Archive -Once the serialized file ( `.pt`) is created, it can be used with `torch-model-archiver` as ususal. Use the following command to package the model. +Once the serialized file (`.pt`) is created, it can be used with `torch-model-archiver` as usual. + +Use the following command to package `rn50_int8_jit.pt` into `rn50_ipex_int8.mar`. ``` torch-model-archiver --model-name rn50_ipex_int8 --version 1.0 --serialized-file rn50_int8_jit.pt --handler image_classifier ``` +Similarly, use the following command in the [Huggingface_Transformers directory](https://github.com/pytorch/serve/tree/master/examples/Huggingface_Transformers) to package `bert_int8_jit.pt` into `bert_ipex_int8.mar`. + +``` +torch-model-archiver --model-name bert_ipex_int8 --version 1.0 --serialized-file bert_int8_jit.pt --handler ./Transformer_handler_generalized.py --extra-files "./setup_config.json,./Seq_classification_artifacts/index_to_name.json" +``` + ### 3. Start TorchServe to serve the model Make sure to set `ipex_enable=true` in `config.properties`. Use the following command to start TorchServe with IPEX. ``` @@ -223,3 +255,88 @@ $ cat logs/model_log.log 2021-12-02 06:15:03,982 - __main__ - INFO - LD_PRELOAD=/lib/libiomp5.so ``` + +### Benchmarking with Launcher Core Pinning +As described previously in [TorchServe with Launcher](#torchserve-with-launcher), launcher core pinning boosts the performance of multi-worker inference. We'll demonstrate launcher core pinning with the TorchServe benchmark, but keep in mind that it is a generic feature applicable to any TorchServe multi-worker inference use case. + +For example, assume running 4 workers +``` +python benchmark-ab.py --workers 4 +``` +on a machine with Intel(R) Xeon(R) Platinum 8180 CPU, 2 sockets, 28 cores per socket, 2 threads per core. Launcher will bind worker 0 to cores 0-13, worker 1 to cores 14-27, worker 2 to cores 28-41, and worker 3 to cores 42-55. + +All that needs to be done to use TorchServe with launcher's core pinning is to enable launcher in `config.properties`.
+ +Add the following line to `config.properties` in the benchmark directory to use launcher's core pinning: +``` +cpu_launcher_enable=true +``` + +CPU usage is shown below: +![launcher_core_pinning](https://user-images.githubusercontent.com/93151422/159063975-e7e8d4b0-e083-4733-bdb6-4d92bdc10556.gif) + +4 main worker threads were launched; each then launched num_physical_cores/num_workers (14) threads affinitized to its assigned physical cores. +
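+
+The core assignment above follows a simple partitioning of the physical cores among the workers. Below is a minimal Python sketch of that arithmetic, for illustration only (the launcher's actual implementation may differ), assuming the same machine as above: 56 physical cores split across 2 NUMA nodes with 28 cores each.
+
+```python
+# Launcher-style core pinning: divide the physical cores evenly among the
+# workers and bind each worker to a contiguous block of cores.
+# Assumes 56 physical cores on 2 NUMA nodes (28 cores per node), as above;
+# for these numbers each block fits entirely within a single NUMA node.
+num_physical_cores = 56
+cores_per_node = 28
+num_workers = 4
+
+cores_per_worker = num_physical_cores // num_workers  # 14
+for worker_idx in range(num_workers):
+    start = worker_idx * cores_per_worker
+    end = (worker_idx + 1) * cores_per_worker - 1
+    node = start // cores_per_node  # NUMA node owning this block of cores
+    print(f"worker {worker_idx}: numactl -C {start}-{end} -m {node}")
+```
+
+Running the sketch prints the same `numactl -C`/`-m` bindings for workers 0-3 that appear in the `model_log.log` output below.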

+$ cat logs/model_log.log
+2022-03-24 10:41:32,223 - __main__ - INFO - Use TCMalloc memory allocator
+2022-03-24 10:41:32,223 - __main__ - INFO - OMP_NUM_THREADS=14
+2022-03-24 10:41:32,223 - __main__ - INFO - Using Intel OpenMP
+2022-03-24 10:41:32,223 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0
+2022-03-24 10:41:32,223 - __main__ - INFO - KMP_BLOCKTIME=1
+2022-03-24 10:41:32,223 - __main__ - INFO - LD_PRELOAD=/lib/libiomp5.so:/lib/libtcmalloc.so
+2022-03-24 10:41:32,223 - __main__ - INFO - numactl -C 0-13 -m 0 /bin/python -u /lib/python/site-packages/ts/model_service_worker.py --sock-type unix --sock-name /tmp/.ts.sock.9000
+
+2022-03-24 10:49:03,760 - __main__ - INFO - Use TCMalloc memory allocator
+2022-03-24 10:49:03,761 - __main__ - INFO - OMP_NUM_THREADS=14
+2022-03-24 10:49:03,762 - __main__ - INFO - Using Intel OpenMP
+2022-03-24 10:49:03,762 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0
+2022-03-24 10:49:03,762 - __main__ - INFO - KMP_BLOCKTIME=1
+2022-03-24 10:49:03,762 - __main__ - INFO - LD_PRELOAD=/lib/libiomp5.so:/lib/libtcmalloc.so
+2022-03-24 10:49:03,763 - __main__ - INFO - numactl -C 14-27 -m 0 /bin/python -u /lib/python/site-packages/ts/model_service_worker.py --sock-type unix --sock-name /tmp/.ts.sock.9001
+
+2022-03-24 10:49:26,274 - __main__ - INFO - Use TCMalloc memory allocator
+2022-03-24 10:49:26,274 - __main__ - INFO - OMP_NUM_THREADS=14
+2022-03-24 10:49:26,274 - __main__ - INFO - Using Intel OpenMP
+2022-03-24 10:49:26,274 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0
+2022-03-24 10:49:26,274 - __main__ - INFO - KMP_BLOCKTIME=1
+2022-03-24 10:49:26,274 - __main__ - INFO - LD_PRELOAD=/lib/libiomp5.so:/lib/libtcmalloc.so
+2022-03-24 10:49:26,274 - __main__ - INFO - numactl -C 28-41 -m 1 /bin/python -u /lib/python/site-packages/ts/model_service_worker.py --sock-type unix --sock-name /tmp/.ts.sock.9002
+
+2022-03-24 10:49:42,975 - __main__ - INFO - Use TCMalloc memory allocator
+2022-03-24 10:49:42,975 - __main__ - INFO - OMP_NUM_THREADS=14
+2022-03-24 10:49:42,975 - __main__ - INFO - Using Intel OpenMP
+2022-03-24 10:49:42,975 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0
+2022-03-24 10:49:42,975 - __main__ - INFO - KMP_BLOCKTIME=1
+2022-03-24 10:49:42,975 - __main__ - INFO - LD_PRELOAD=/lib/libiomp5.so:/lib/libtcmalloc.so
+2022-03-24 10:49:42,975 - __main__ - INFO - numactl -C 42-55 -m 1 /bin/python -u /lib/python/site-packages/ts/model_service_worker.py --sock-type unix --sock-name /tmp/.ts.sock.9003
+
+ +## Performance Boost with IPEX and Launcher + +![pdt_perf](https://user-images.githubusercontent.com/93151422/159067306-dfd604e3-8c66-4365-91ae-c99f68d972d5.png) + + +The figure above shows the performance improvement of TorchServe with IPEX and launcher on ResNet50 and BERT-base-uncased. The official TorchServe [apache-bench benchmark](https://github.com/pytorch/serve/tree/master/benchmarks#benchmarking-with-apache-bench) on Amazon EC2 m6i.24xlarge was used to collect the results.3 Add the following lines to `config.properties` to reproduce the results. Notice that launcher is configured such that a single instance uses all physical cores on a single socket to avoid cross-socket communication and core overlap. + +``` +ipex_enable=true +cpu_launcher_enable=true +cpu_launcher_args=--node_id 0 --enable_jemalloc +``` +Use the following command to reproduce the results. +``` +python benchmark-ab.py --url {modelUrl} --input {inputPath} --concurrency 1 +``` + +For example, run the following command to reproduce the latency performance of ResNet50 with the IPEX int8 data type and batch size 1. Please refer to [Creating and Exporting INT8 model for IPEX](#creating-and-exporting-int8-model-for-ipex) for steps to create the `rn50_ipex_int8.mar` file for ResNet50 with the IPEX int8 data type. +``` +python benchmark-ab.py --url 'file:///model_store/rn50_ipex_int8.mar' --concurrency 1 +``` + +Similarly, run the following command to reproduce the latency performance of BERT with the IPEX int8 data type and batch size 1. Please refer to [Creating and Exporting INT8 model for IPEX](#creating-and-exporting-int8-model-for-ipex) for steps to create the `bert_ipex_int8.mar` file for BERT with the IPEX int8 data type. +``` +python benchmark-ab.py --url 'file:///model_store/bert_ipex_int8.mar' --input '../examples/Huggingface_Transformers/Seq_classification_artifacts/sample_text_captum_input.txt' --concurrency 1 +``` + +3. Amazon EC2 m6i.24xlarge was used for benchmarking purposes only. For multi-core instances, IPEX optimizations automatically scale and leverage full instance resources. diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java index 74c8f79ef6..c8f8b1d6a6 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java @@ -85,6 +85,17 @@ public int getNumRunningWorkers(ModelVersionName modelVersionName) { return numWorking; } + /** + * Checks if cpu_launcher is enabled and currentWorkers > 0 (i.e., not initializing workers). + * Workers are restarted so that when dynamically scaling the number of workers, cores that were + * pinned to killed workers by the launcher are not left unutilized. If isRestart, workers are + * restarted to re-distribute cores that were pinned to killed workers to the remaining, alive + * workers. 
+ */ + public boolean isLauncherRestartWorkers(int currentWorkers) { + return configManager.isCPULauncherEnabled() && currentWorkers > 0; + } + public CompletableFuture modelChanged( Model model, boolean isStartup, boolean isCleanUp) { synchronized (model.getModelVersionName()) { @@ -92,6 +103,8 @@ public CompletableFuture modelChanged( CompletableFuture future = new CompletableFuture<>(); int minWorker = model.getMinWorkers(); int maxWorker = model.getMaxWorkers(); + // Sets restartNumWorkers to the updated minWorker after scale up/down + int restartNumWorkers = minWorker; List threads; if (minWorker == 0) { threads = workers.remove(model.getModelVersionName()); @@ -109,6 +122,18 @@ public CompletableFuture modelChanged( } int currentWorkers = threads.size(); + boolean isRestartWorkers = isLauncherRestartWorkers(currentWorkers); + + if (isRestartWorkers) { + logger.warn( + "removing {} current thread(s) prior to restarting {} thread(s)", + currentWorkers, + minWorker); + // By setting maxWorker and minWorker to 0, removes all currentWorkers + maxWorker = 0; + minWorker = 0; + } + if (currentWorkers < minWorker) { addThreads(threads, model, minWorker - currentWorkers, future); } else { @@ -150,6 +175,13 @@ public CompletableFuture modelChanged( } future.complete(HttpURLConnection.HTTP_OK); } + + // After removing all currentWorkers, add back (i.e., restart) restartNumWorkers + if (isRestartWorkers) { + logger.warn("restarting {} thread(s)", restartNumWorkers); + addThreads(threads, model, restartNumWorkers, future); + } + if (!isStartup && !isSnapshotSaved && !isCleanUp && !model.isWorkflowModel()) { SnapshotManager.getInstance().saveSnapshot(); } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java index ac06cd0049..dbd4af358e 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java @@ -21,6 +21,7 @@ public class WorkerLifeCycle { private static final Logger logger = LoggerFactory.getLogger(WorkerLifeCycle.class); private ConfigManager configManager; + private ModelManager modelManager = ModelManager.getInstance(); private Model model; private int pid = -1; private Process process; @@ -30,10 +31,14 @@ public class WorkerLifeCycle { private ReaderThread errReader; private ReaderThread outReader; private String launcherArgs; + private int numWorker; + private int currNumRunningWorkers; public WorkerLifeCycle(ConfigManager configManager, Model model) { this.configManager = configManager; this.model = model; + this.numWorker = model.getMinWorkers(); + this.currNumRunningWorkers = modelManager.getNumRunningWorkers(model.getModelVersionName()); } public Process getProcess() { @@ -44,8 +49,6 @@ public ArrayList launcherArgsToList() { ArrayList arrlist = new ArrayList(); arrlist.add("-m"); arrlist.add("intel_extension_for_pytorch.cpu.launch"); - arrlist.add("--ninstance"); - arrlist.add("1"); if (launcherArgs != null && launcherArgs.length() > 1) { String[] argarray = launcherArgs.split(" "); for (int i = 0; i < argarray.length; i++) { @@ -99,6 +102,16 @@ public void startWorker(int port) throws WorkerInitializationException, Interrup if (launcherAvailable) { ArrayList args = launcherArgsToList(); argl.addAll(args); + + // multi-worker core pinning + if (this.numWorker > 1) { + argl.add("--ninstances"); + argl.add(String.valueOf(this.numWorker)); + argl.add("--instance_idx"); + // 
instance_idx is 0-indexed + argl.add(String.valueOf(this.currNumRunningWorkers)); + } + } else { logger.warn( "CPU launcher is enabled but launcher is not available. Proceeding without launcher."); diff --git a/requirements/developer.txt b/requirements/developer.txt index 530a4489c4..6c1deae14e 100644 --- a/requirements/developer.txt +++ b/requirements/developer.txt @@ -1,4 +1,5 @@ -r common.txt +intel_extension_for_pytorch; sys_platform != 'win32' mock pytest pylint==2.6.0 diff --git a/test/config_ipex.properties b/test/config_ipex.properties new file mode 100644 index 0000000000..02a1528068 --- /dev/null +++ b/test/config_ipex.properties @@ -0,0 +1,6 @@ +inference_address=http://127.0.0.1:8080 +management_address=http://127.0.0.1:8081 + +ipex_enable=true +cpu_launcher_enable=true + diff --git a/test/pytest/test_example_intel_extension_for_pytorch.py b/test/pytest/test_example_intel_extension_for_pytorch.py new file mode 100644 index 0000000000..9b79eb3ae5 --- /dev/null +++ b/test/pytest/test_example_intel_extension_for_pytorch.py @@ -0,0 +1,186 @@ +import json +import os +import subprocess + +import pytest +import requests +import test_utils +from test_handler import run_inference_using_url_with_data + +REPO_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") +snapshot_file_ipex = os.path.join(REPO_ROOT, "test/config_ipex.properties") +data_file_kitten = os.path.join(REPO_ROOT, "examples/image_classifier/kitten.jpg") +TS_LOG = "./logs/ts_log.log" + +MANAGEMENT_API = "http://localhost:8081" +INFERENCE_API = "http://localhost:8080" + +ipex_launcher_available = False +cmd = ["python", "-m", "intel_extension_for_pytorch.cpu.launch", "--no_python", "ls"] +r = subprocess.run(cmd) +if r.returncode == 0: + ipex_launcher_available = True + + +def setup_module(): + test_utils.torchserve_cleanup() + response = requests.get( + "https://torchserve.pytorch.org/mar_files/resnet-18.mar", allow_redirects=True + ) + open(test_utils.MODEL_STORE + "resnet-18.mar", "wb").write(response.content) + + +def setup_torchserve(): + if os.path.exists(TS_LOG): + os.remove(TS_LOG) + test_utils.start_torchserve( + test_utils.MODEL_STORE, snapshot_file_ipex, gen_mar=False + ) + + +def get_worker_affinity(num_workers, worker_idx): + from intel_extension_for_pytorch.cpu.launch import CPUinfo + + cpuinfo = CPUinfo() + num_cores = cpuinfo.physical_core_nums() + + num_cores_per_worker = num_cores // num_workers + start = worker_idx * num_cores_per_worker + end = (worker_idx + 1) * num_cores_per_worker - 1 + curr_worker_cores = [i for i in range(start, end + 1)] + affinity = "numactl -C {}-{}".format(str(start), str(end)) + affinity += " -m {}".format( + ",".join( + [str(numa_id) for numa_id in cpuinfo.numa_aware_check(curr_worker_cores)] + ) + ) + return affinity + + +def run_inference_with_core_pinning(): + files = { + "data": (data_file_kitten, open(data_file_kitten, "rb")), + } + response = run_inference_using_url_with_data( + "http://localhost:8080/predictions/resnet-18", files + ) + return response + + +def scale_workers_with_core_pinning(scaled_num_workers): + params = (("min_worker", str(scaled_num_workers)),) + requests.put("http://localhost:8081/models/resnet-18", params=params) + response = requests.get("http://localhost:8081/models/resnet-18") + return response + + +@pytest.mark.skipif( + not ipex_launcher_available, + reason="Make sure intel-extension-for-pytorch is installed", +) +def test_single_worker_affinity(): + num_workers = 1 + worker_idx = 0 + setup_torchserve() + requests.post( + 
"http://localhost:8081/models?initial_workers={}&synchronous=true&url=resnet-18.mar".format( + num_workers + ) + ) + + response = run_inference_with_core_pinning() + assert ( + response.status_code == 200 + ), "single-worker inference with core pinning failed" + + affinity = get_worker_affinity(num_workers, worker_idx) + assert affinity in open(TS_LOG).read(), "workers are not correctly pinned to cores" + + +@pytest.mark.skipif( + not ipex_launcher_available, + reason="Make sure intel-extension-for-pytorch is installed", +) +def test_multi_worker_affinity(): + num_workers = 4 + setup_torchserve() + requests.post( + "http://localhost:8081/models?initial_workers={}&synchronous=true&url=resnet-18.mar".format( + num_workers + ) + ) + + response = run_inference_with_core_pinning() + assert ( + response.status_code == 200 + ), "multi-worker inference with core pinning failed" + + for worker_idx in range(num_workers): + curr_worker_affinity = get_worker_affinity(num_workers, worker_idx) + assert ( + curr_worker_affinity in open(TS_LOG).read() + ), "workers are not correctly pinned to cores" + + +@pytest.mark.skipif( + not ipex_launcher_available, + reason="Make sure intel-extension-for-pytorch is installed", +) +def test_worker_scale_up_affinity(): + initial_num_workers = 2 + setup_torchserve() + requests.post( + "http://localhost:8081/models?initial_workers={}&synchronous=true&url=resnet-18.mar".format( + initial_num_workers + ) + ) + + scaled_up_num_workers = 4 + response = scale_workers_with_core_pinning(scaled_up_num_workers) + resnet18_list = json.loads(response.content) + assert ( + len(resnet18_list[0]["workers"]) == scaled_up_num_workers + ), "workers failed to scale up with core pinning" + + response = run_inference_with_core_pinning() + assert ( + response.status_code == 200 + ), "scaled up workers inference with core pinning failed" + + for worker_idx in range(scaled_up_num_workers): + curr_worker_affinity = get_worker_affinity(scaled_up_num_workers, worker_idx) + assert ( + curr_worker_affinity in open(TS_LOG).read() + ), "workers are not correctly pinned to cores" + + +@pytest.mark.skipif( + not ipex_launcher_available, + reason="Make sure intel-extension-for-pytorch is installed", +) +def test_worker_scale_down_affinity(): + initial_num_workers = 4 + setup_torchserve() + requests.post( + "http://localhost:8081/models?initial_workers={}&synchronous=true&url=resnet-18.mar".format( + initial_num_workers + ) + ) + + scaled_down_num_workers = 2 + response = scale_workers_with_core_pinning(scaled_down_num_workers) + resnet18_list = json.loads(response.content) + assert ( + len(resnet18_list[0]["workers"]) == scaled_down_num_workers + ), "workers failed to scale down with core pinning" + + response = run_inference_with_core_pinning() + assert ( + response.status_code == 200 + ), "scaled down workers inference with core pinning failed" + + for worker_idx in range(scaled_down_num_workers): + curr_worker_affinity = get_worker_affinity(scaled_down_num_workers, worker_idx) + assert ( + curr_worker_affinity in open(TS_LOG).read() + ), "workers are not correctly pinned to cores"